You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@ariatosca.apache.org by mx...@apache.org on 2016/12/08 09:59:36 UTC
[3/4] incubator-ariatosca git commit: ARIA-30 SQL based storage
implementation
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/storage/models.py
----------------------------------------------------------------------
diff --git a/aria/storage/models.py b/aria/storage/models.py
index d24ad75..6302e66 100644
--- a/aria/storage/models.py
+++ b/aria/storage/models.py
@@ -36,16 +36,30 @@ classes:
* ProviderContext - provider context implementation model.
* Plugin - plugin implementation model.
"""
-
+from collections import namedtuple
from datetime import datetime
-from types import NoneType
-from .structures import Field, IterPointerField, Model, uuid_generator, PointerField
+from sqlalchemy.ext.declarative.base import declared_attr
+
+from .structures import (
+ SQLModelBase,
+ Column,
+ Integer,
+ Text,
+ DateTime,
+ Boolean,
+ Enum,
+ String,
+ Float,
+ List,
+ Dict,
+ foreign_key,
+ one_to_many_relationship,
+ relationship_to_self,
+ orm)
__all__ = (
- 'Model',
'Blueprint',
- 'Snapshot',
'Deployment',
'DeploymentUpdateStep',
'DeploymentUpdate',
@@ -59,66 +73,192 @@ __all__ = (
'Plugin',
)
-# todo: sort this, maybe move from mgr or move from aria???
-ACTION_TYPES = ()
-ENTITY_TYPES = ()
+
+#pylint: disable=no-self-argument
-class Blueprint(Model):
+class Blueprint(SQLModelBase):
"""
- A Model which represents a blueprint
+ Blueprint model representation.
"""
- plan = Field(type=dict)
- id = Field(type=basestring, default=uuid_generator)
- description = Field(type=(basestring, NoneType))
- created_at = Field(type=datetime)
- updated_at = Field(type=datetime)
- main_file_name = Field(type=basestring)
+ __tablename__ = 'blueprints'
+ name = Column(Text, index=True)
+ created_at = Column(DateTime, nullable=False, index=True)
+ main_file_name = Column(Text, nullable=False)
+ plan = Column(Dict, nullable=False)
+ updated_at = Column(DateTime)
+ description = Column(Text)
-class Snapshot(Model):
+
+class Deployment(SQLModelBase):
"""
- A Model which represents a snapshot
+ Deployment model representation.
"""
- CREATED = 'created'
+ __tablename__ = 'deployments'
+
+ _private_fields = ['blueprint_id']
+
+ blueprint_id = foreign_key(Blueprint.id)
+
+ name = Column(Text, index=True)
+ created_at = Column(DateTime, nullable=False, index=True)
+ description = Column(Text)
+ inputs = Column(Dict)
+ groups = Column(Dict)
+ permalink = Column(Text)
+ policy_triggers = Column(Dict)
+ policy_types = Column(Dict)
+ outputs = Column(Dict)
+ scaling_groups = Column(Dict)
+ updated_at = Column(DateTime)
+ workflows = Column(Dict)
+
+ @declared_attr
+ def blueprint(cls):
+ return one_to_many_relationship(cls, Blueprint, cls.blueprint_id)
+
+
+class Execution(SQLModelBase):
+ """
+ Execution model representation.
+ """
+ __tablename__ = 'executions'
+
+ TERMINATED = 'terminated'
FAILED = 'failed'
- CREATING = 'creating'
- UPLOADED = 'uploaded'
- END_STATES = [CREATED, FAILED, UPLOADED]
+ CANCELLED = 'cancelled'
+ PENDING = 'pending'
+ STARTED = 'started'
+ CANCELLING = 'cancelling'
+ FORCE_CANCELLING = 'force_cancelling'
- id = Field(type=basestring, default=uuid_generator)
- created_at = Field(type=datetime)
- status = Field(type=basestring)
- error = Field(type=basestring, default=None)
+ STATES = [TERMINATED, FAILED, CANCELLED, PENDING, STARTED, CANCELLING, FORCE_CANCELLING]
+ END_STATES = [TERMINATED, FAILED, CANCELLED]
+ ACTIVE_STATES = [state for state in STATES if state not in END_STATES]
+ VALID_TRANSITIONS = {
+ PENDING: [STARTED, CANCELLED],
+ STARTED: END_STATES + [CANCELLING],
+ CANCELLING: END_STATES
+ }
-class Deployment(Model):
+ @orm.validates('status')
+ def validate_status(self, key, value):
+ """Validation function that verifies execution status transitions are OK"""
+ try:
+ current_status = getattr(self, key)
+ except AttributeError:
+ return
+ valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, [])
+ if all([current_status is not None,
+ current_status != value,
+ value not in valid_transitions]):
+ raise ValueError('Cannot change execution status from {current} to {new}'.format(
+ current=current_status,
+ new=value))
+ return value
+
+ deployment_id = foreign_key(Deployment.id)
+ blueprint_id = foreign_key(Blueprint.id)
+ _private_fields = ['deployment_id', 'blueprint_id']
+
+ created_at = Column(DateTime, index=True)
+ started_at = Column(DateTime, nullable=True, index=True)
+ ended_at = Column(DateTime, nullable=True, index=True)
+ error = Column(Text, nullable=True)
+ is_system_workflow = Column(Boolean, nullable=False, default=False)
+ parameters = Column(Dict)
+ status = Column(Enum(*STATES, name='execution_status'), default=PENDING)
+ workflow_name = Column(Text, nullable=False)
+
+ @declared_attr
+ def deployment(cls):
+ return one_to_many_relationship(cls, Deployment, cls.deployment_id)
+
+ @declared_attr
+ def blueprint(cls):
+ return one_to_many_relationship(cls, Blueprint, cls.blueprint_id)
+
+ def __str__(self):
+ return '<{0} id=`{1}` (status={2})>'.format(
+ self.__class__.__name__,
+ self.id,
+ self.status
+ )
+
+
+class DeploymentUpdate(SQLModelBase):
"""
- A Model which represents a deployment
+ Deployment update model representation.
"""
- id = Field(type=basestring, default=uuid_generator)
- description = Field(type=(basestring, NoneType))
- created_at = Field(type=datetime)
- updated_at = Field(type=datetime)
- blueprint_id = Field(type=basestring)
- workflows = Field(type=dict)
- inputs = Field(type=dict, default=lambda: {})
- policy_types = Field(type=dict, default=lambda: {})
- policy_triggers = Field(type=dict, default=lambda: {})
- groups = Field(type=dict, default=lambda: {})
- outputs = Field(type=dict, default=lambda: {})
- scaling_groups = Field(type=dict, default=lambda: {})
-
-
-class DeploymentUpdateStep(Model):
+ __tablename__ = 'deployment_updates'
+
+ deployment_id = foreign_key(Deployment.id)
+ execution_id = foreign_key(Execution.id, nullable=True)
+ _private_fields = ['execution_id', 'deployment_id']
+
+ created_at = Column(DateTime, nullable=False, index=True)
+ deployment_plan = Column(Dict, nullable=False)
+ deployment_update_node_instances = Column(Dict)
+ deployment_update_deployment = Column(Dict)
+ deployment_update_nodes = Column(Dict)
+ modified_entity_ids = Column(Dict)
+ state = Column(Text)
+
+ @declared_attr
+ def execution(cls):
+ return one_to_many_relationship(cls, Execution, cls.execution_id)
+
+ @declared_attr
+ def deployment(cls):
+ return one_to_many_relationship(cls, Deployment, cls.deployment_id)
+
+ def to_dict(self, suppress_error=False, **kwargs):
+ dep_update_dict = super(DeploymentUpdate, self).to_dict(suppress_error)
+ # Taking care of the fact the DeploymentSteps are objects
+ dep_update_dict['steps'] = [step.to_dict() for step in self.steps]
+ return dep_update_dict
+
+
+class DeploymentUpdateStep(SQLModelBase):
"""
- A Model which represents a deployment update step
+ Deployment update step model representation.
"""
- id = Field(type=basestring, default=uuid_generator)
- action = Field(type=basestring, choices=ACTION_TYPES)
- entity_type = Field(type=basestring, choices=ENTITY_TYPES)
- entity_id = Field(type=basestring)
- supported = Field(type=bool, default=True)
+ __tablename__ = 'deployment_update_steps'
+ _action_types = namedtuple('ACTION_TYPES', 'ADD, REMOVE, MODIFY')
+ ACTION_TYPES = _action_types(ADD='add', REMOVE='remove', MODIFY='modify')
+ _entity_types = namedtuple(
+ 'ENTITY_TYPES',
+ 'NODE, RELATIONSHIP, PROPERTY, OPERATION, WORKFLOW, OUTPUT, DESCRIPTION, GROUP, '
+ 'POLICY_TYPE, POLICY_TRIGGER, PLUGIN')
+ ENTITY_TYPES = _entity_types(
+ NODE='node',
+ RELATIONSHIP='relationship',
+ PROPERTY='property',
+ OPERATION='operation',
+ WORKFLOW='workflow',
+ OUTPUT='output',
+ DESCRIPTION='description',
+ GROUP='group',
+ POLICY_TYPE='policy_type',
+ POLICY_TRIGGER='policy_trigger',
+ PLUGIN='plugin'
+ )
+
+ deployment_update_id = foreign_key(DeploymentUpdate.id)
+ _private_fields = ['deployment_update_id']
+
+ action = Column(Enum(*ACTION_TYPES, name='action_type'), nullable=False)
+ entity_id = Column(Text, nullable=False)
+ entity_type = Column(Enum(*ENTITY_TYPES, name='entity_type'), nullable=False)
+
+ @declared_attr
+ def deployment_update(cls):
+ return one_to_many_relationship(cls,
+ DeploymentUpdate,
+ cls.deployment_update_id,
+ backreference='steps')
def __hash__(self):
return hash((self.id, self.entity_id))
@@ -148,265 +288,225 @@ class DeploymentUpdateStep(Model):
return False
-class DeploymentUpdate(Model):
+class DeploymentModification(SQLModelBase):
"""
- A Model which represents a deployment update
+ Deployment modification model representation.
"""
- INITIALIZING = 'initializing'
- SUCCESSFUL = 'successful'
- UPDATING = 'updating'
- FINALIZING = 'finalizing'
- EXECUTING_WORKFLOW = 'executing_workflow'
- FAILED = 'failed'
+ __tablename__ = 'deployment_modifications'
- STATES = [
- INITIALIZING,
- SUCCESSFUL,
- UPDATING,
- FINALIZING,
- EXECUTING_WORKFLOW,
- FAILED,
- ]
-
- # '{0}-{1}'.format(kwargs['deployment_id'], uuid4())
- id = Field(type=basestring, default=uuid_generator)
- deployment_id = Field(type=basestring)
- state = Field(type=basestring, choices=STATES, default=INITIALIZING)
- deployment_plan = Field()
- deployment_update_nodes = Field(default=None)
- deployment_update_node_instances = Field(default=None)
- deployment_update_deployment = Field(default=None)
- modified_entity_ids = Field(default=None)
- execution_id = Field(type=basestring)
- steps = IterPointerField(type=DeploymentUpdateStep, default=())
-
-
-class Execution(Model):
- """
- A Model which represents an execution
- """
+ STARTED = 'started'
+ FINISHED = 'finished'
+ ROLLEDBACK = 'rolledback'
- class _Validation(object):
-
- @staticmethod
- def execution_status_transition_validation(_, value, instance):
- """Validation function that verifies execution status transitions are OK"""
- try:
- current_status = instance.status
- except AttributeError:
- return
- valid_transitions = Execution.VALID_TRANSITIONS.get(current_status, [])
- if current_status != value and value not in valid_transitions:
- raise ValueError('Cannot change execution status from {current} to {new}'.format(
- current=current_status,
- new=value))
+ STATES = [STARTED, FINISHED, ROLLEDBACK]
+ END_STATES = [FINISHED, ROLLEDBACK]
- TERMINATED = 'terminated'
- FAILED = 'failed'
- CANCELLED = 'cancelled'
- PENDING = 'pending'
- STARTED = 'started'
- CANCELLING = 'cancelling'
- STATES = (
- TERMINATED,
- FAILED,
- CANCELLED,
- PENDING,
- STARTED,
- CANCELLING,
- )
- END_STATES = [TERMINATED, FAILED, CANCELLED]
- ACTIVE_STATES = [state for state in STATES if state not in END_STATES]
- VALID_TRANSITIONS = {
- PENDING: [STARTED, CANCELLED],
- STARTED: END_STATES + [CANCELLING],
- CANCELLING: END_STATES
- }
+ deployment_id = foreign_key(Deployment.id)
+ _private_fields = ['deployment_id']
- id = Field(type=basestring, default=uuid_generator)
- status = Field(type=basestring, choices=STATES,
- validation_func=_Validation.execution_status_transition_validation)
- deployment_id = Field(type=basestring)
- workflow_id = Field(type=basestring)
- blueprint_id = Field(type=basestring)
- created_at = Field(type=datetime, default=datetime.utcnow)
- started_at = Field(type=datetime, default=None)
- ended_at = Field(type=datetime, default=None)
- error = Field(type=basestring, default=None)
- parameters = Field()
+ context = Column(Dict)
+ created_at = Column(DateTime, nullable=False, index=True)
+ ended_at = Column(DateTime, index=True)
+ modified_nodes = Column(Dict)
+ node_instances = Column(Dict)
+ status = Column(Enum(*STATES, name='deployment_modification_status'))
+ @declared_attr
+ def deployment(cls):
+ return one_to_many_relationship(cls,
+ Deployment,
+ cls.deployment_id,
+ backreference='modifications')
-class Relationship(Model):
+
+class Node(SQLModelBase):
"""
- A Model which represents a relationship
+ Node model representation.
"""
- id = Field(type=basestring, default=uuid_generator)
- source_id = Field(type=basestring)
- target_id = Field(type=basestring)
- source_interfaces = Field(type=dict)
- source_operations = Field(type=dict)
- target_interfaces = Field(type=dict)
- target_operations = Field(type=dict)
- type = Field(type=basestring)
- type_hierarchy = Field(type=list)
- properties = Field(type=dict)
-
-
-class Node(Model):
+ __tablename__ = 'nodes'
+
+ # See base class for an explanation on these properties
+ is_id_unique = False
+
+ name = Column(Text, index=True)
+ _private_fields = ['deployment_id', 'host_id']
+ deployment_id = foreign_key(Deployment.id)
+ host_id = foreign_key('nodes.id', nullable=True)
+
+ @declared_attr
+ def deployment(cls):
+ return one_to_many_relationship(cls, Deployment, cls.deployment_id)
+
+ deploy_number_of_instances = Column(Integer, nullable=False)
+ # TODO: This probably should be a foreign key, but there's no guarantee
+ # in the code, currently, that the host will be created beforehand
+ max_number_of_instances = Column(Integer, nullable=False)
+ min_number_of_instances = Column(Integer, nullable=False)
+ number_of_instances = Column(Integer, nullable=False)
+ planned_number_of_instances = Column(Integer, nullable=False)
+ plugins = Column(Dict)
+ plugins_to_install = Column(Dict)
+ properties = Column(Dict)
+ operations = Column(Dict)
+ type = Column(Text, nullable=False, index=True)
+ type_hierarchy = Column(List)
+
+ @declared_attr
+ def host(cls):
+ return relationship_to_self(cls, cls.host_id, cls.id)
+
+
+class Relationship(SQLModelBase):
"""
- A Model which represents a node
+ Relationship model representation.
"""
- id = Field(type=basestring, default=uuid_generator)
- blueprint_id = Field(type=basestring)
- type = Field(type=basestring)
- type_hierarchy = Field()
- number_of_instances = Field(type=int)
- planned_number_of_instances = Field(type=int)
- deploy_number_of_instances = Field(type=int)
- host_id = Field(type=basestring, default=None)
- properties = Field(type=dict)
- operations = Field(type=dict)
- plugins = Field(type=list, default=())
- relationships = IterPointerField(type=Relationship)
- plugins_to_install = Field(type=list, default=())
- min_number_of_instances = Field(type=int)
- max_number_of_instances = Field(type=int)
-
- def relationships_by_target(self, target_id):
- """
- Retreives all of the relationship by target.
- :param target_id: the node id of the target of the relationship
- :yields: a relationship which target and node with the specified target_id
- """
- for relationship in self.relationships:
- if relationship.target_id == target_id:
- yield relationship
- # todo: maybe add here Exception if isn't exists (didn't yield one's)
+ __tablename__ = 'relationships'
+ _private_fields = ['source_node_id', 'target_node_id']
-class RelationshipInstance(Model):
- """
- A Model which represents a relationship instance
- """
- id = Field(type=basestring, default=uuid_generator)
- target_id = Field(type=basestring)
- target_name = Field(type=basestring)
- source_id = Field(type=basestring)
- source_name = Field(type=basestring)
- type = Field(type=basestring)
- relationship = PointerField(type=Relationship)
+ source_node_id = foreign_key(Node.id)
+ target_node_id = foreign_key(Node.id)
+
+ @declared_attr
+ def source_node(cls):
+ return one_to_many_relationship(cls,
+ Node,
+ cls.source_node_id,
+ 'outbound_relationships')
+
+ @declared_attr
+ def target_node(cls):
+ return one_to_many_relationship(cls,
+ Node,
+ cls.target_node_id,
+ 'inbound_relationships')
+ source_interfaces = Column(Dict)
+ source_operations = Column(Dict, nullable=False)
+ target_interfaces = Column(Dict)
+ target_operations = Column(Dict, nullable=False)
+ type = Column(String, nullable=False)
+ type_hierarchy = Column(List)
+ properties = Column(Dict)
-class NodeInstance(Model):
+
+class NodeInstance(SQLModelBase):
"""
- A Model which represents a node instance
+ Node instance model representation.
"""
- # todo: add statuses
- UNINITIALIZED = 'uninitialized'
- INITIALIZING = 'initializing'
- CREATING = 'creating'
- CONFIGURING = 'configuring'
- STARTING = 'starting'
- DELETED = 'deleted'
- STOPPING = 'stopping'
- DELETING = 'deleting'
- STATES = (
- UNINITIALIZED,
- INITIALIZING,
- CREATING,
- CONFIGURING,
- STARTING,
- DELETED,
- STOPPING,
- DELETING
- )
+ __tablename__ = 'node_instances'
- id = Field(type=basestring, default=uuid_generator)
- deployment_id = Field(type=basestring)
- runtime_properties = Field(type=dict)
- state = Field(type=basestring, choices=STATES, default=UNINITIALIZED)
- version = Field(type=(basestring, NoneType))
- relationship_instances = IterPointerField(type=RelationshipInstance)
- node = PointerField(type=Node)
- host_id = Field(type=basestring, default=None)
- scaling_groups = Field(default=())
-
- def relationships_by_target(self, target_id):
- """
- Retreives all of the relationship by target.
- :param target_id: the instance id of the target of the relationship
- :yields: a relationship instance which target and node with the specified target_id
- """
- for relationship_instance in self.relationship_instances:
- if relationship_instance.target_id == target_id:
- yield relationship_instance
- # todo: maybe add here Exception if isn't exists (didn't yield one's)
+ node_id = foreign_key(Node.id)
+ deployment_id = foreign_key(Deployment.id)
+ host_id = foreign_key('node_instances.id', nullable=True)
+
+ _private_fields = ['node_id', 'host_id']
+
+ name = Column(Text, index=True)
+ runtime_properties = Column(Dict)
+ scaling_groups = Column(Dict)
+ state = Column(Text, nullable=False)
+ version = Column(Integer, default=1)
+
+ @declared_attr
+ def deployment(cls):
+ return one_to_many_relationship(cls, Deployment, cls.deployment_id)
+
+ @declared_attr
+ def node(cls):
+ return one_to_many_relationship(cls, Node, cls.node_id)
+ @declared_attr
+ def host(cls):
+ return relationship_to_self(cls, cls.host_id, cls.id)
-class DeploymentModification(Model):
+
+class RelationshipInstance(SQLModelBase):
"""
- A Model which represents a deployment modification
+ Relationship instance model representation.
"""
- STARTED = 'started'
- FINISHED = 'finished'
- ROLLEDBACK = 'rolledback'
- END_STATES = [FINISHED, ROLLEDBACK]
+ __tablename__ = 'relationship_instances'
+
+ relationship_id = foreign_key(Relationship.id)
+ source_node_instance_id = foreign_key(NodeInstance.id)
+ target_node_instance_id = foreign_key(NodeInstance.id)
+
+ _private_fields = ['relationship_storage_id',
+ 'source_node_instance_id',
+ 'target_node_instance_id']
- id = Field(type=basestring, default=uuid_generator)
- deployment_id = Field(type=basestring)
- modified_nodes = Field(type=(dict, NoneType))
- added_and_related = IterPointerField(type=NodeInstance)
- removed_and_related = IterPointerField(type=NodeInstance)
- extended_and_related = IterPointerField(type=NodeInstance)
- reduced_and_related = IterPointerField(type=NodeInstance)
- # before_modification = IterPointerField(type=NodeInstance)
- status = Field(type=basestring, choices=(STARTED, FINISHED, ROLLEDBACK))
- created_at = Field(type=datetime)
- ended_at = Field(type=(datetime, NoneType))
- context = Field()
-
-
-class ProviderContext(Model):
+ @declared_attr
+ def source_node_instance(cls):
+ return one_to_many_relationship(cls,
+ NodeInstance,
+ cls.source_node_instance_id,
+ 'outbound_relationship_instances')
+
+ @declared_attr
+ def target_node_instance(cls):
+ return one_to_many_relationship(cls,
+ NodeInstance,
+ cls.target_node_instance_id,
+ 'inbound_relationship_instances')
+
+ @declared_attr
+ def relationship(cls):
+ return one_to_many_relationship(cls, Relationship, cls.relationship_id)
+
+
+class ProviderContext(SQLModelBase):
"""
- A Model which represents a provider context
+ Provider context model representation.
"""
- id = Field(type=basestring, default=uuid_generator)
- context = Field(type=dict)
- name = Field(type=basestring)
+ __tablename__ = 'provider_context'
+
+ name = Column(Text, nullable=False)
+ context = Column(Dict, nullable=False)
-class Plugin(Model):
+class Plugin(SQLModelBase):
"""
- A Model which represents a plugin
+ Plugin model representation.
"""
- id = Field(type=basestring, default=uuid_generator)
- package_name = Field(type=basestring)
- archive_name = Field(type=basestring)
- package_source = Field(type=dict)
- package_version = Field(type=basestring)
- supported_platform = Field(type=basestring)
- distribution = Field(type=basestring)
- distribution_version = Field(type=basestring)
- distribution_release = Field(type=basestring)
- wheels = Field()
- excluded_wheels = Field()
- supported_py_versions = Field(type=list)
- uploaded_at = Field(type=datetime)
-
-
-class Task(Model):
+ __tablename__ = 'plugins'
+
+ archive_name = Column(Text, nullable=False, index=True)
+ distribution = Column(Text)
+ distribution_release = Column(Text)
+ distribution_version = Column(Text)
+ excluded_wheels = Column(Dict)
+ package_name = Column(Text, nullable=False, index=True)
+ package_source = Column(Text)
+ package_version = Column(Text)
+ supported_platform = Column(Dict)
+ supported_py_versions = Column(Dict)
+ uploaded_at = Column(DateTime, nullable=False, index=True)
+ wheels = Column(Dict, nullable=False)
+
+
+class Task(SQLModelBase):
"""
    A Model which represents a task
"""
- class _Validation(object):
+ __tablename__ = 'task'
+ node_instance_id = foreign_key(NodeInstance.id, nullable=True)
+ relationship_instance_id = foreign_key(RelationshipInstance.id, nullable=True)
+ execution_id = foreign_key(Execution.id, nullable=True)
+
+ _private_fields = ['node_instance_id',
+ 'relationship_instance_id',
+ 'execution_id']
- @staticmethod
- def validate_max_attempts(_, value, *args):
- """Validates that max attempts is either -1 or a positive number"""
- if value < 1 and value != Task.INFINITE_RETRIES:
- raise ValueError('Max attempts can be either -1 (infinite) or any positive number. '
- 'Got {value}'.format(value=value))
+ @declared_attr
+ def node_instance(cls):
+ return one_to_many_relationship(cls, NodeInstance, cls.node_instance_id)
+
+ @declared_attr
+ def relationship_instance(cls):
+ return one_to_many_relationship(cls,
+ RelationshipInstance,
+ cls.relationship_instance_id)
PENDING = 'pending'
RETRYING = 'retrying'
@@ -422,23 +522,51 @@ class Task(Model):
SUCCESS,
FAILED,
)
+
WAIT_STATES = [PENDING, RETRYING]
END_STATES = [SUCCESS, FAILED]
+
+ @orm.validates('max_attempts')
+ def validate_max_attempts(self, _, value): # pylint: disable=no-self-use
+ """Validates that max attempts is either -1 or a positive number"""
+ if value < 1 and value != Task.INFINITE_RETRIES:
+ raise ValueError('Max attempts can be either -1 (infinite) or any positive number. '
+ 'Got {value}'.format(value=value))
+ return value
+
INFINITE_RETRIES = -1
- id = Field(type=basestring, default=uuid_generator)
- status = Field(type=basestring, choices=STATES, default=PENDING)
- execution_id = Field(type=basestring)
- due_at = Field(type=datetime, default=datetime.utcnow)
- started_at = Field(type=datetime, default=None)
- ended_at = Field(type=datetime, default=None)
- max_attempts = Field(type=int, default=1, validation_func=_Validation.validate_max_attempts)
- retry_count = Field(type=int, default=0)
- retry_interval = Field(type=(int, float), default=0)
- ignore_failure = Field(type=bool, default=False)
+ status = Column(Enum(*STATES), name='status', default=PENDING)
+
+ due_at = Column(DateTime, default=datetime.utcnow)
+ started_at = Column(DateTime, default=None)
+ ended_at = Column(DateTime, default=None)
+ max_attempts = Column(Integer, default=1)
+ retry_count = Column(Integer, default=0)
+ retry_interval = Column(Float, default=0)
+ ignore_failure = Column(Boolean, default=False)
# Operation specific fields
- name = Field(type=basestring)
- operation_mapping = Field(type=basestring)
- actor = Field()
- inputs = Field(type=dict, default=lambda: {})
+ name = Column(String)
+ operation_mapping = Column(String)
+ inputs = Column(Dict)
+
+ @declared_attr
+ def execution(cls):
+ return one_to_many_relationship(cls, Execution, cls.execution_id)
+
+ @property
+ def actor(self):
+ """
+ Return the actor of the task
+ :return:
+ """
+ return self.node_instance or self.relationship_instance
+
+ @classmethod
+ def as_node_instance(cls, instance_id, **kwargs):
+ return cls(node_instance_id=instance_id, **kwargs)
+
+ @classmethod
+ def as_relationship_instance(cls, instance_id, **kwargs):
+ return cls(relationship_instance_id=instance_id, **kwargs)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/storage/sql_mapi.py
----------------------------------------------------------------------
diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py
new file mode 100644
index 0000000..cde40c2
--- /dev/null
+++ b/aria/storage/sql_mapi.py
@@ -0,0 +1,382 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+SQLAlchemy based MAPI
+"""
+
+from sqlalchemy.exc import SQLAlchemyError
+
+from aria.utils.collections import OrderedDict
+from aria.storage import (
+ api,
+ exceptions
+)
+
+
+class SQLAlchemyModelAPI(api.ModelAPI):
+ """
+ SQL based MAPI.
+ """
+
+ def __init__(self,
+ engine,
+ session,
+ **kwargs):
+ super(SQLAlchemyModelAPI, self).__init__(**kwargs)
+ self._engine = engine
+ self._session = session
+
+ def get(self, entry_id, include=None, **kwargs):
+ """Return a single result based on the model class and element ID
+ """
+ query = self._get_query(include, {'id': entry_id})
+ result = query.first()
+
+ if not result:
+ raise exceptions.StorageError(
+ 'Requested {0} with ID `{1}` was not found'
+ .format(self.model_cls.__name__, entry_id)
+ )
+ return result
+
+ def get_by_name(self, entry_name, include=None, **kwargs):
+ assert hasattr(self.model_cls, 'name')
+ result = self.list(include=include, filters={'name': entry_name})
+ if not result:
+ raise exceptions.StorageError(
+ 'Requested {0} with NAME `{1}` was not found'
+ .format(self.model_cls.__name__, entry_name)
+ )
+ elif len(result) > 1:
+ raise exceptions.StorageError(
+ 'Requested {0} with NAME `{1}` returned more than 1 value'
+ .format(self.model_cls.__name__, entry_name)
+ )
+ else:
+ return result[0]
+
+ def list(self,
+ include=None,
+ filters=None,
+ pagination=None,
+ sort=None,
+ **kwargs):
+ query = self._get_query(include, filters, sort)
+
+ results, total, size, offset = self._paginate(query, pagination)
+
+ return ListResult(
+ items=results,
+ metadata=dict(total=total,
+ size=size,
+ offset=offset)
+ )
+
+ def iter(self,
+ include=None,
+ filters=None,
+ sort=None,
+ **kwargs):
+ """Return a (possibly empty) list of `model_class` results
+ """
+ return iter(self._get_query(include, filters, sort))
+
+ def put(self, entry, **kwargs):
+ """Create a `model_class` instance from a serializable `model` object
+
+ :param entry: A dict with relevant kwargs, or an instance of a class
+ that has a `to_dict` method, and whose attributes match the columns
+ of `model_class` (might also be just an instance of `model_class`)
+ :return: An instance of `model_class`
+ """
+ self._session.add(entry)
+ self._safe_commit()
+ return entry
+
+ def delete(self, entry, **kwargs):
+ """Delete a single result based on the model class and element ID
+ """
+ self._load_relationships(entry)
+ self._session.delete(entry)
+ self._safe_commit()
+ return entry
+
+ def update(self, entry, **kwargs):
+ """Add `instance` to the DB session, and attempt to commit
+
+ :return: The updated instance
+ """
+ return self.put(entry)
+
+ def refresh(self, entry):
+ """Reload the instance with fresh information from the DB
+
+ :param entry: Instance to be re-loaded from the DB
+ :return: The refreshed instance
+ """
+ self._session.refresh(entry)
+ self._load_relationships(entry)
+ return entry
+
+ def _destroy_connection(self):
+ pass
+
+ def _establish_connection(self):
+ pass
+
+ def create(self, checkfirst=True, **kwargs):
+ self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)
+
+ def drop(self):
+ """
+ Drop the table from the storage.
+ :return:
+ """
+ self.model_cls.__table__.drop(self._engine)
+
+ def _safe_commit(self):
+ """Try to commit changes in the session. Roll back if exception raised
+ Excepts SQLAlchemy errors and rollbacks if they're caught
+ """
+ try:
+ self._session.commit()
+ except (SQLAlchemyError, ValueError) as e:
+ self._session.rollback()
+ raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
+
+ def _get_base_query(self, include, joins):
+ """Create the initial query from the model class and included columns
+
+ :param include: A (possibly empty) list of columns to include in
+ the query
+ :return: An SQLAlchemy AppenderQuery object
+ """
+ # If only some columns are included, query through the session object
+ if include:
+ # Make sure that attributes come before association proxies
+ include.sort(key=lambda x: x.is_clause_element)
+ query = self._session.query(*include)
+ else:
+ # If all columns should be returned, query directly from the model
+ query = self._session.query(self.model_cls)
+
+ if not self._skip_joining(joins, include):
+ for join_table in joins:
+ query = query.join(join_table)
+
+ return query
+
+ @staticmethod
+ def _get_joins(model_class, columns):
+ """Get a list of all the tables on which we need to join
+
+ :param columns: A set of all columns involved in the query
+ """
+ joins = [] # Using a list instead of a set because order is important
+ for column_name in columns:
+ column = getattr(model_class, column_name)
+ while not column.is_attribute:
+ column = column.remote_attr
+ if column.is_attribute:
+ join_class = column.class_
+ else:
+ join_class = column.local_attr.class_
+
+ # Don't add the same class more than once
+ if join_class not in joins:
+ joins.append(join_class)
+ return joins
+
+ @staticmethod
+ def _skip_joining(joins, include):
+ """Dealing with an edge case where the only included column comes from
+ an other table. In this case, we mustn't join on the same table again
+
+ :param joins: A list of tables on which we're trying to join
+ :param include: The list of columns to include in the query
+ :return: True if we need to skip joining
+ """
+ if not joins:
+ return True
+ join_table_names = [t.__tablename__ for t in joins]
+
+ if len(include) != 1:
+ return False
+
+ column = include[0]
+ if column.is_clause_element:
+ table_name = column.element.table.name
+ else:
+ table_name = column.class_.__tablename__
+ return table_name in join_table_names
+
+ @staticmethod
+ def _sort_query(query, sort=None):
+ """Add sorting clauses to the query
+
+ :param query: Base SQL query
+ :param sort: An optional dictionary where keys are column names to
+ sort by, and values are the order (asc/desc)
+ :return: An SQLAlchemy AppenderQuery object
+ """
+ if sort:
+ for column, order in sort.items():
+ if order == 'desc':
+ column = column.desc()
+ query = query.order_by(column)
+ return query
+
    def _filter_query(self, query, filters):
        """Add filter clauses to the query

        :param query: Base SQL query
        :param filters: An optional dictionary where keys are column names to
        filter by, and values are values applicable for those columns (or lists
        of such values)
        :return: An SQLAlchemy AppenderQuery object
        """
        # Thin wrapper around _add_value_filter; presumably kept as a separate
        # method as an extension point for other filter kinds — TODO confirm.
        return self._add_value_filter(query, filters)
+
+ @staticmethod
+ def _add_value_filter(query, filters):
+ for column, value in filters.items():
+ if isinstance(value, (list, tuple)):
+ query = query.filter(column.in_(value))
+ else:
+ query = query.filter(column == value)
+
+ return query
+
    def _get_query(self,
                   include=None,
                   filters=None,
                   sort=None):
        """Get an SQL query object based on the params passed

        :param include: An optional list of columns to include in the query
        :param filters: An optional dictionary where keys are column names to
        filter by, and values are values applicable for those columns (or lists
        of such values)
        :param sort: An optional dictionary where keys are column names to
        sort by, and values are the order (asc/desc)
        :return: A sorted and filtered query with only the relevant
        columns
        """
        # Convert column names to actual SQLA column objects and compute the
        # joins needed to satisfy every referenced column
        include, filters, sort, joins = self._get_joins_and_converted_columns(
            include, filters, sort
        )

        # Build the query in stages: base query (with joins), then filtering,
        # then sorting — each stage returns a new query object
        query = self._get_base_query(include, joins)
        query = self._filter_query(query, filters)
        query = self._sort_query(query, sort)
        return query
+
+ def _get_joins_and_converted_columns(self,
+ include,
+ filters,
+ sort):
+ """Get a list of tables on which we need to join and the converted
+ `include`, `filters` and `sort` arguments (converted to actual SQLA
+ column/label objects instead of column names)
+ """
+ include = include or []
+ filters = filters or dict()
+ sort = sort or OrderedDict()
+
+ all_columns = set(include) | set(filters.keys()) | set(sort.keys())
+ joins = self._get_joins(self.model_cls, all_columns)
+
+ include, filters, sort = self._get_columns_from_field_names(
+ include, filters, sort
+ )
+ return include, filters, sort, joins
+
+ def _get_columns_from_field_names(self,
+ include,
+ filters,
+ sort):
+ """Go over the optional parameters (include, filters, sort), and
+ replace column names with actual SQLA column objects
+ """
+ include = [self._get_column(c) for c in include]
+ filters = dict((self._get_column(c), filters[c]) for c in filters)
+ sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
+
+ return include, filters, sort
+
+ def _get_column(self, column_name):
+ """Return the column on which an action (filtering, sorting, etc.)
+ would need to be performed. Can be either an attribute of the class,
+ or an association proxy linked to a relationship the class has
+ """
+ column = getattr(self.model_cls, column_name)
+ if column.is_attribute:
+ return column
+ else:
+ # We need to get to the underlying attribute, so we move on to the
+ # next remote_attr until we reach one
+ while not column.remote_attr.is_attribute:
+ column = column.remote_attr
+ # Put a label on the remote attribute with the name of the column
+ return column.remote_attr.label(column_name)
+
+ @staticmethod
+ def _paginate(query, pagination):
+ """Paginate the query by size and offset
+
+ :param query: Current SQLAlchemy query object
+ :param pagination: An optional dict with size and offset keys
+ :return: A tuple with four elements:
+ - res ults: `size` items starting from `offset`
+ - the total count of items
+ - `size` [default: 0]
+ - `offset` [default: 0]
+ """
+ if pagination:
+ size = pagination.get('size', 0)
+ offset = pagination.get('offset', 0)
+ total = query.order_by(None).count() # Fastest way to count
+ results = query.limit(size).offset(offset).all()
+ return results, total, size, offset
+ else:
+ results = query.all()
+ return results, len(results), 0, 0
+
    @staticmethod
    def _load_relationships(instance):
        """A helper method used to overcome a problem where the relationships
        that rely on joins aren't being loaded automatically

        :param instance: A model instance whose relationship attributes are
        touched (and thereby loaded) one by one
        """
        for rel in instance.__mapper__.relationships:
            # Attribute access triggers the lazy load; the returned value is
            # intentionally discarded
            getattr(instance, rel.key)
+
+
class ListResult(object):
    """A sequence-like container for the results of a list operation.

    Exposes the retrieved ``items`` alongside ``metadata`` describing the
    result, and supports ``len()``, iteration, and indexing directly.
    """
    def __init__(self, items, metadata):
        self.items = items
        self.metadata = metadata

    def __getitem__(self, item):
        return self.items[item]

    def __iter__(self):
        return iter(self.items)

    def __len__(self):
        return len(self.items)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/storage/structures.py
----------------------------------------------------------------------
diff --git a/aria/storage/structures.py b/aria/storage/structures.py
index b02366e..8dbd2a9 100644
--- a/aria/storage/structures.py
+++ b/aria/storage/structures.py
@@ -27,281 +27,218 @@ classes:
* Model - abstract model implementation.
"""
import json
-from itertools import count
-from uuid import uuid4
-
-from .exceptions import StorageError
-from ..logger import LoggerMixin
-from ..utils.validation import ValidatorMixin
-
-__all__ = (
- 'uuid_generator',
- 'Field',
- 'IterField',
- 'PointerField',
- 'IterPointerField',
- 'Model',
- 'Storage',
+
+from sqlalchemy.ext.mutable import Mutable
+from sqlalchemy.orm import relationship, backref
+from sqlalchemy.ext.declarative import declarative_base
+# pylint: disable=unused-import
+from sqlalchemy.ext.associationproxy import association_proxy
+from sqlalchemy import (
+ schema,
+ VARCHAR,
+ ARRAY,
+ Column,
+ Integer,
+ Text,
+ DateTime,
+ Boolean,
+ Enum,
+ String,
+ PickleType,
+ Float,
+ TypeDecorator,
+ ForeignKey,
+ orm,
)
+from aria.storage import exceptions
+
+Model = declarative_base()
-def uuid_generator():
- """
- wrapper function which generates ids
- """
- return str(uuid4())
+def foreign_key(foreign_key_column, nullable=False):
+ """Return a ForeignKey object with the relevant
-class Field(ValidatorMixin):
+ :param foreign_key_column: Unique id column in the parent table
+ :param nullable: Should the column be allowed to remain empty
"""
- A single field implementation
+ return Column(
+ ForeignKey(foreign_key_column, ondelete='CASCADE'),
+ nullable=nullable
+ )
+
+
+def one_to_many_relationship(child_class,
+ parent_class,
+ foreign_key_column,
+ backreference=None):
+ """Return a one-to-many SQL relationship object
+ Meant to be used from inside the *child* object
+
+ :param parent_class: Class of the parent table
+ :param child_class: Class of the child table
+ :param foreign_key_column: The column of the foreign key
+ :param backreference: The name to give to the reference to the child
"""
- NO_DEFAULT = 'NO_DEFAULT'
-
- try:
- # python 3 syntax
- _next_id = count().__next__
- except AttributeError:
- # python 2 syntax
- _next_id = count().next
- _ATTRIBUTE_NAME = '_cache_{0}'.format
-
- def __init__(
- self,
- type=None,
- choices=(),
- validation_func=None,
- default=NO_DEFAULT,
- **kwargs):
- """
- Simple field manager.
+ backreference = backreference or child_class.__tablename__
+ return relationship(
+ parent_class,
+ primaryjoin=lambda: parent_class.id == foreign_key_column,
+ # The following line make sure that when the *parent* is
+ # deleted, all its connected children are deleted as well
+ backref=backref(backreference, cascade='all')
+ )
- :param type: possible type of the field.
- :param choices: a set of possible field values.
- :param default: default field value.
- :param kwargs: kwargs to be passed to next in line classes.
- """
- self.type = type
- self.choices = choices
- self.default = default
- self.validation_func = validation_func
- super(Field, self).__init__(**kwargs)
-
- def __get__(self, instance, owner):
- if instance is None:
- return self
- field_name = self._field_name(instance)
- try:
- return getattr(instance, self._ATTRIBUTE_NAME(field_name))
- except AttributeError as exc:
- if self.default == self.NO_DEFAULT:
- raise AttributeError(
- str(exc).replace(self._ATTRIBUTE_NAME(field_name), field_name))
-
- default_value = self.default() if callable(self.default) else self.default
- setattr(instance, self._ATTRIBUTE_NAME(field_name), default_value)
- return default_value
-
- def __set__(self, instance, value):
- field_name = self._field_name(instance)
- self.validate_value(field_name, value, instance)
- setattr(instance, self._ATTRIBUTE_NAME(field_name), value)
-
- def validate_value(self, name, value, instance):
- """
- Validates the value of the field.
- :param name: the name of the field.
- :param value: the value of the field.
- :param instance: the instance containing the field.
- """
- if self.default != self.NO_DEFAULT and value == self.default:
- return
- if self.type:
- self.validate_instance(name, value, self.type)
- if self.choices:
- self.validate_in_choice(name, value, self.choices)
- if self.validation_func:
- self.validation_func(name, value, instance)
-
- def _field_name(self, instance):
- """
- retrieves the field name from the instance.
-
- :param Field instance: the instance which holds the field.
- :return: name of the field
- :rtype: basestring
- """
- for name, member in vars(instance.__class__).iteritems():
- if member is self:
- return name
+def relationship_to_self(self_cls, parent_key, self_key):
+ return relationship(
+ self_cls,
+ foreign_keys=parent_key,
+ remote_side=self_key
+ )
-class IterField(Field):
+class _MutableType(TypeDecorator):
"""
- Represents an iterable field.
+ Base class for mutable column types whose values are stored as JSON text.
"""
- def __init__(self, **kwargs):
- """
- Simple iterable field manager.
- This field type don't have choices option.
-
- :param kwargs: kwargs to be passed to next in line classes.
- """
- super(IterField, self).__init__(choices=(), **kwargs)
+ @property
+ def python_type(self):
+ raise NotImplementedError
- def validate_value(self, name, values, *args):
- """
- Validates the value of each iterable value.
+ def process_literal_param(self, value, dialect):
+ pass
- :param name: the name of the field.
- :param values: the values of the field.
- """
- for value in values:
- self.validate_instance(name, value, self.type)
+ impl = VARCHAR
+ def process_bind_param(self, value, dialect):
+ if value is not None:
+ value = json.dumps(value)
+ return value
-class PointerField(Field):
- """
- A single pointer field implementation.
-
- Any PointerField points via id to another document.
- """
+ def process_result_value(self, value, dialect):
+ if value is not None:
+ value = json.loads(value)
+ return value
- def __init__(self, type, **kwargs):
- assert issubclass(type, Model)
- super(PointerField, self).__init__(type=type, **kwargs)
+class _DictType(_MutableType):
+ @property
+ def python_type(self):
+ return dict
-class IterPointerField(IterField, PointerField):
- """
- An iterable pointers field.
- Any IterPointerField points via id to other documents.
- """
- pass
+class _ListType(_MutableType):
+ @property
+ def python_type(self):
+ return list
-class Model(object):
+class _MutableDict(Mutable, dict):
"""
- Base class for all of the storage models.
+ Enables tracking for dict values.
"""
- id = None
+ @classmethod
+ def coerce(cls, key, value):
+ "Convert plain dictionaries to MutableDict."
- def __init__(self, **fields):
- """
- Abstract class for any model in the storage.
- The Initializer creates attributes according to the (keyword arguments) that given
- Each value is validated according to the Field.
- Each model has to have and ID Field.
+ if not isinstance(value, _MutableDict):
+ if isinstance(value, dict):
+ return _MutableDict(value)
- :param fields: each item is validated and transformed into instance attributes.
- """
- self._assert_model_have_id_field(**fields)
- missing_fields, unexpected_fields = self._setup_fields(fields)
+ # this call will raise ValueError
+ try:
+ return Mutable.coerce(key, value)
+ except ValueError as e:
+ raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
+ else:
+ return value
- if missing_fields:
- raise StorageError(
- 'Model {name} got missing keyword arguments: {fields}'.format(
- name=self.__class__.__name__, fields=missing_fields))
+ def __setitem__(self, key, value):
+ "Detect dictionary set events and emit change events."
- if unexpected_fields:
- raise StorageError(
- 'Model {name} got unexpected keyword arguments: {fields}'.format(
- name=self.__class__.__name__, fields=unexpected_fields))
+ dict.__setitem__(self, key, value)
+ self.changed()
- def __repr__(self):
- return '{name}(fields={0})'.format(sorted(self.fields), name=self.__class__.__name__)
+ def __delitem__(self, key):
+ "Detect dictionary del events and emit change events."
- def __eq__(self, other):
- return (
- isinstance(other, self.__class__) and
- self.fields_dict == other.fields_dict)
+ dict.__delitem__(self, key)
+ self.changed()
- @property
- def fields(self):
- """
- Iterates over the fields of the model.
- :yields: the class's field name
- """
- for name, field in vars(self.__class__).items():
- if isinstance(field, Field):
- yield name
- @property
- def fields_dict(self):
- """
- Transforms the instance attributes into a dict.
+class _MutableList(Mutable, list):
- :return: all fields in dict format.
- :rtype dict
- """
- return dict((name, getattr(self, name)) for name in self.fields)
+ @classmethod
+ def coerce(cls, key, value):
+ "Convert plain lists to _MutableList."
- @property
- def json(self):
- """
- Transform the dict of attributes into json
- :return:
- """
- return json.dumps(self.fields_dict)
+ if not isinstance(value, _MutableList):
+ if isinstance(value, list):
+ return _MutableList(value)
- @classmethod
- def _assert_model_have_id_field(cls, **fields_initializer_values):
- if not getattr(cls, 'id', None):
- raise StorageError('Model {cls.__name__} must have id field'.format(cls=cls))
-
- if cls.id.default == cls.id.NO_DEFAULT and 'id' not in fields_initializer_values:
- raise StorageError(
- 'Model {cls.__name__} is missing required '
- 'keyword-only argument: "id"'.format(cls=cls))
-
- def _setup_fields(self, input_fields):
- missing = []
- for field_name in self.fields:
+ # this call will raise ValueError
try:
- field_obj = input_fields.pop(field_name)
- setattr(self, field_name, field_obj)
- except KeyError:
- field = getattr(self.__class__, field_name)
- if field.default == field.NO_DEFAULT:
- missing.append(field_name)
+ return Mutable.coerce(key, value)
+ except ValueError as e:
+ raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
+ else:
+ return value
+
+ def __setitem__(self, key, value):
+ list.__setitem__(self, key, value)
+ self.changed()
+
+ def __delitem__(self, key):
+ list.__delitem__(self, key)
+
- unexpected_fields = input_fields.keys()
- return missing, unexpected_fields
+Dict = _MutableDict.as_mutable(_DictType)
+List = _MutableList.as_mutable(_ListType)
-class Storage(LoggerMixin):
+class SQLModelBase(Model):
"""
- Represents the storage
+ Abstract base class for all SQL models that allows [de]serialization
"""
- def __init__(self, driver, items=(), **kwargs):
- super(Storage, self).__init__(**kwargs)
- self.driver = driver
- self.registered = {}
- for item in items:
- self.register(item)
- self.logger.debug('{name} object is ready: {0!r}'.format(
- self, name=self.__class__.__name__))
+ # SQLAlchemy syntax
+ __abstract__ = True
- def __repr__(self):
- return '{name}(driver={self.driver})'.format(
- name=self.__class__.__name__, self=self)
+ # This would be overridden once the models are created. Created for pylint.
+ __table__ = None
+
+ _private_fields = []
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
- def __getattr__(self, item):
- try:
- return self.registered[item]
- except KeyError:
- return super(Storage, self).__getattribute__(item)
+ def to_dict(self, suppress_error=False):
+ """Return a dict representation of the model
- def setup(self):
+ :param suppress_error: If set to True, sets `None` to attributes that
+ it's unable to retrieve (e.g., if a relationship wasn't established
+ yet, and so it's impossible to access a property through it)
"""
- Setup and create all storage items
+ if suppress_error:
+ res = dict()
+ for field in self.fields():
+ try:
+ field_value = getattr(self, field)
+ except AttributeError:
+ field_value = None
+ res[field] = field_value
+ else:
+ # Can't simply call here `self.to_response()` because inheriting
+ # class might override it, but we always need the same code here
+ res = dict((f, getattr(self, f)) for f in self.fields())
+ return res
+
+ @classmethod
+ def fields(cls):
+ """Return the list of field names for this table
+
+ Mostly for backwards compatibility in the code (that uses `fields`)
"""
- for name, api in self.registered.iteritems():
- try:
- api.create()
- self.logger.debug(
- 'setup {name} in storage {self!r}'.format(name=name, self=self))
- except StorageError:
- pass
+ return set(cls.__table__.columns.keys()) - set(cls._private_fields)
+
+ def __repr__(self):
+ return '<{0} id=`{1}`>'.format(self.__class__.__name__, self.id)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/aria/utils/application.py
----------------------------------------------------------------------
diff --git a/aria/utils/application.py b/aria/utils/application.py
index b1a7fcc..113e054 100644
--- a/aria/utils/application.py
+++ b/aria/utils/application.py
@@ -117,7 +117,7 @@ class StorageManager(LoggerMixin):
updated_at=now,
main_file_name=main_file_name,
)
- self.model_storage.blueprint.store(blueprint)
+ self.model_storage.blueprint.put(blueprint)
self.logger.debug('created blueprint model storage entry')
def create_nodes_storage(self):
@@ -138,7 +138,7 @@ class StorageManager(LoggerMixin):
scalable = node_copy.pop('capabilities')['scalable']['properties']
for index, relationship in enumerate(node_copy['relationships']):
relationship = self.model_storage.relationship.model_cls(**relationship)
- self.model_storage.relationship.store(relationship)
+ self.model_storage.relationship.put(relationship)
node_copy['relationships'][index] = relationship
node_copy = self.model_storage.node.model_cls(
@@ -149,7 +149,7 @@ class StorageManager(LoggerMixin):
max_number_of_instances=scalable['max_instances'],
number_of_instances=scalable['current_instances'],
**node_copy)
- self.model_storage.node.store(node_copy)
+ self.model_storage.node.put(node_copy)
def create_deployment_storage(self):
"""
@@ -190,7 +190,7 @@ class StorageManager(LoggerMixin):
created_at=now,
updated_at=now
)
- self.model_storage.deployment.store(deployment)
+ self.model_storage.deployment.put(deployment)
self.logger.debug('created deployment model storage entry')
def create_node_instances_storage(self):
@@ -213,7 +213,7 @@ class StorageManager(LoggerMixin):
type=relationship_instance['type'],
target_id=relationship_instance['target_id'])
relationship_instances.append(relationship_instance_model)
- self.model_storage.relationship_instance.store(relationship_instance_model)
+ self.model_storage.relationship_instance.put(relationship_instance_model)
node_instance_model = self.model_storage.node_instance.model_cls(
node=node_model,
@@ -224,7 +224,7 @@ class StorageManager(LoggerMixin):
version='1.0',
relationship_instances=relationship_instances)
- self.model_storage.node_instance.store(node_instance_model)
+ self.model_storage.node_instance.put(node_instance_model)
self.logger.debug('created node-instances model storage entries')
def create_plugin_storage(self, plugin_id, source):
@@ -258,7 +258,7 @@ class StorageManager(LoggerMixin):
supported_py_versions=plugin.get('supported_python_versions'),
uploaded_at=now
)
- self.model_storage.plugin.store(plugin)
+ self.model_storage.plugin.put(plugin)
self.logger.debug('created plugin model storage entry')
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/requirements.txt
----------------------------------------------------------------------
diff --git a/requirements.txt b/requirements.txt
index e6d5393..7e87c67 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,3 +23,4 @@ Jinja2==2.8
shortuuid==0.4.3
CacheControl[filecache]==0.11.6
clint==0.5.1
+SQLAlchemy==1.1.4
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/mock/context.py
----------------------------------------------------------------------
diff --git a/tests/mock/context.py b/tests/mock/context.py
index 5fda07e..1904140 100644
--- a/tests/mock/context.py
+++ b/tests/mock/context.py
@@ -15,23 +15,53 @@
from aria import application_model_storage
from aria.orchestrator import context
+from aria.storage.sql_mapi import SQLAlchemyModelAPI
from . import models
-from ..storage import InMemoryModelDriver
-def simple(**kwargs):
- storage = application_model_storage(InMemoryModelDriver())
- storage.setup()
- storage.blueprint.store(models.get_blueprint())
- storage.deployment.store(models.get_deployment())
+def simple(api_kwargs, **kwargs):
+ model_storage = application_model_storage(SQLAlchemyModelAPI, api_kwargs=api_kwargs)
+ blueprint = models.get_blueprint()
+ model_storage.blueprint.put(blueprint)
+ deployment = models.get_deployment(blueprint)
+ model_storage.deployment.put(deployment)
+
+ #################################################################################
+ # Creating a simple deployment with node -> node as a graph
+
+ dependency_node = models.get_dependency_node(deployment)
+ model_storage.node.put(dependency_node)
+ storage_dependency_node = model_storage.node.get(dependency_node.id)
+
+ dependency_node_instance = models.get_dependency_node_instance(storage_dependency_node)
+ model_storage.node_instance.put(dependency_node_instance)
+ storage_dependency_node_instance = model_storage.node_instance.get(dependency_node_instance.id)
+
+ dependent_node = models.get_dependent_node(deployment)
+ model_storage.node.put(dependent_node)
+ storage_dependent_node = model_storage.node.get(dependent_node.id)
+
+ dependent_node_instance = models.get_dependent_node_instance(storage_dependent_node)
+ model_storage.node_instance.put(dependent_node_instance)
+ storage_dependent_node_instance = model_storage.node_instance.get(dependent_node_instance.id)
+
+ relationship = models.get_relationship(storage_dependent_node, storage_dependency_node)
+ model_storage.relationship.put(relationship)
+ storage_relationship = model_storage.relationship.get(relationship.id)
+ relationship_instance = models.get_relationship_instance(
+ relationship=storage_relationship,
+ target_instance=storage_dependency_node_instance,
+ source_instance=storage_dependent_node_instance
+ )
+ model_storage.relationship_instance.put(relationship_instance)
+
final_kwargs = dict(
name='simple_context',
- model_storage=storage,
+ model_storage=model_storage,
resource_storage=None,
- deployment_id=models.DEPLOYMENT_ID,
- workflow_id=models.WORKFLOW_ID,
- execution_id=models.EXECUTION_ID,
+ deployment_id=deployment.id,
+ workflow_name=models.WORKFLOW_NAME,
task_max_attempts=models.TASK_MAX_ATTEMPTS,
task_retry_interval=models.TASK_RETRY_INTERVAL
)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/mock/models.py
----------------------------------------------------------------------
diff --git a/tests/mock/models.py b/tests/mock/models.py
index 327b0b9..e2e3d2f 100644
--- a/tests/mock/models.py
+++ b/tests/mock/models.py
@@ -19,24 +19,24 @@ from aria.storage import models
from . import operations
-DEPLOYMENT_ID = 'test_deployment_id'
-BLUEPRINT_ID = 'test_blueprint_id'
-WORKFLOW_ID = 'test_workflow_id'
-EXECUTION_ID = 'test_execution_id'
+DEPLOYMENT_NAME = 'test_deployment_id'
+BLUEPRINT_NAME = 'test_blueprint_id'
+WORKFLOW_NAME = 'test_workflow_id'
+EXECUTION_NAME = 'test_execution_id'
TASK_RETRY_INTERVAL = 1
TASK_MAX_ATTEMPTS = 1
-DEPENDENCY_NODE_ID = 'dependency_node'
-DEPENDENCY_NODE_INSTANCE_ID = 'dependency_node_instance'
-DEPENDENT_NODE_ID = 'dependent_node'
-DEPENDENT_NODE_INSTANCE_ID = 'dependent_node_instance'
+DEPENDENCY_NODE_NAME = 'dependency_node'
+DEPENDENCY_NODE_INSTANCE_NAME = 'dependency_node_instance'
+DEPENDENT_NODE_NAME = 'dependent_node'
+DEPENDENT_NODE_INSTANCE_NAME = 'dependent_node_instance'
+RELATIONSHIP_NAME = 'relationship'
+RELATIONSHIP_INSTANCE_NAME = 'relationship_instance'
-def get_dependency_node():
+def get_dependency_node(deployment):
return models.Node(
- id=DEPENDENCY_NODE_ID,
- host_id=DEPENDENCY_NODE_ID,
- blueprint_id=BLUEPRINT_ID,
+ name=DEPENDENCY_NODE_NAME,
type='test_node_type',
type_hierarchy=[],
number_of_instances=1,
@@ -44,28 +44,28 @@ def get_dependency_node():
deploy_number_of_instances=1,
properties={},
operations=dict((key, {}) for key in operations.NODE_OPERATIONS),
- relationships=[],
min_number_of_instances=1,
max_number_of_instances=1,
+ deployment_id=deployment.id
)
-def get_dependency_node_instance(dependency_node=None):
+def get_dependency_node_instance(dependency_node):
return models.NodeInstance(
- id=DEPENDENCY_NODE_INSTANCE_ID,
- host_id=DEPENDENCY_NODE_INSTANCE_ID,
- deployment_id=DEPLOYMENT_ID,
+ name=DEPENDENCY_NODE_INSTANCE_NAME,
runtime_properties={'ip': '1.1.1.1'},
version=None,
- relationship_instances=[],
- node=dependency_node or get_dependency_node()
+ node_id=dependency_node.id,
+ deployment_id=dependency_node.deployment.id,
+ state='',
+ scaling_groups={}
)
def get_relationship(source=None, target=None):
return models.Relationship(
- source_id=source.id if source is not None else DEPENDENT_NODE_ID,
- target_id=target.id if target is not None else DEPENDENCY_NODE_ID,
+ source_node_id=source.id,
+ target_node_id=target.id,
source_interfaces={},
source_operations=dict((key, {}) for key in operations.RELATIONSHIP_OPERATIONS),
target_interfaces={},
@@ -76,23 +76,18 @@ def get_relationship(source=None, target=None):
)
-def get_relationship_instance(source_instance=None, target_instance=None, relationship=None):
+def get_relationship_instance(source_instance, target_instance, relationship):
return models.RelationshipInstance(
- target_id=target_instance.id if target_instance else DEPENDENCY_NODE_INSTANCE_ID,
- target_name='test_target_name',
- source_id=source_instance.id if source_instance else DEPENDENT_NODE_INSTANCE_ID,
- source_name='test_source_name',
- type='some_type',
- relationship=relationship or get_relationship(target_instance.node
- if target_instance else None)
+ relationship_id=relationship.id,
+ target_node_instance_id=target_instance.id,
+ source_node_instance_id=source_instance.id,
)
-def get_dependent_node(relationship=None):
+def get_dependent_node(deployment):
return models.Node(
- id=DEPENDENT_NODE_ID,
- host_id=DEPENDENT_NODE_ID,
- blueprint_id=BLUEPRINT_ID,
+ name=DEPENDENT_NODE_NAME,
+ deployment_id=deployment.id,
type='test_node_type',
type_hierarchy=[],
number_of_instances=1,
@@ -100,21 +95,20 @@ def get_dependent_node(relationship=None):
deploy_number_of_instances=1,
properties={},
operations=dict((key, {}) for key in operations.NODE_OPERATIONS),
- relationships=[relationship or get_relationship()],
min_number_of_instances=1,
max_number_of_instances=1,
)
-def get_dependent_node_instance(relationship_instance=None, dependent_node=None):
+def get_dependent_node_instance(dependent_node):
return models.NodeInstance(
- id=DEPENDENT_NODE_INSTANCE_ID,
- host_id=DEPENDENT_NODE_INSTANCE_ID,
- deployment_id=DEPLOYMENT_ID,
+ name=DEPENDENT_NODE_INSTANCE_NAME,
runtime_properties={},
version=None,
- relationship_instances=[relationship_instance or get_relationship_instance()],
- node=dependent_node or get_dependency_node()
+ node_id=dependent_node.id,
+ deployment_id=dependent_node.deployment.id,
+ state='',
+ scaling_groups={}
)
@@ -122,7 +116,7 @@ def get_blueprint():
now = datetime.now()
return models.Blueprint(
plan={},
- id=BLUEPRINT_ID,
+ name=BLUEPRINT_NAME,
description=None,
created_at=now,
updated_at=now,
@@ -130,25 +124,31 @@ def get_blueprint():
)
-def get_execution():
+def get_execution(deployment):
return models.Execution(
- id=EXECUTION_ID,
+ deployment_id=deployment.id,
+ blueprint_id=deployment.blueprint.id,
status=models.Execution.STARTED,
- deployment_id=DEPLOYMENT_ID,
- workflow_id=WORKFLOW_ID,
- blueprint_id=BLUEPRINT_ID,
+ workflow_name=WORKFLOW_NAME,
started_at=datetime.utcnow(),
parameters=None
)
-def get_deployment():
+def get_deployment(blueprint):
now = datetime.utcnow()
return models.Deployment(
- id=DEPLOYMENT_ID,
- description=None,
+ name=DEPLOYMENT_NAME,
+ blueprint_id=blueprint.id,
+ description='',
created_at=now,
updated_at=now,
- blueprint_id=BLUEPRINT_ID,
- workflows={}
+ workflows={},
+ inputs={},
+ groups={},
+ permalink='',
+ policy_triggers={},
+ policy_types={},
+ outputs={},
+ scaling_groups={},
)
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/orchestrator/context/test_operation.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/context/test_operation.py b/tests/orchestrator/context/test_operation.py
index 6b3e28d..b5f52a3 100644
--- a/tests/orchestrator/context/test_operation.py
+++ b/tests/orchestrator/context/test_operation.py
@@ -23,7 +23,7 @@ from aria.orchestrator import context
from aria.orchestrator.workflows import api
from aria.orchestrator.workflows.executor import thread
-from tests import mock
+from tests import mock, storage
from . import (
op_path,
op_name,
@@ -34,8 +34,10 @@ global_test_holder = {}
@pytest.fixture
-def ctx():
- return mock.context.simple()
+def ctx(tmpdir):
+ context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir)))
+ yield context
+ storage.release_sqlite_storage(context.model)
@pytest.fixture
@@ -50,14 +52,13 @@ def executor():
def test_node_operation_task_execution(ctx, executor):
operation_name = 'aria.interfaces.lifecycle.create'
- node = mock.models.get_dependency_node()
+ node = ctx.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
node.operations[operation_name] = {
'operation': op_path(my_operation, module_path=__name__)
}
- node_instance = mock.models.get_dependency_node_instance(node)
- ctx.model.node.store(node)
- ctx.model.node_instance.store(node_instance)
+ ctx.model.node.update(node)
+ node_instance = ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
inputs = {'putput': True}
@@ -90,26 +91,19 @@ def test_node_operation_task_execution(ctx, executor):
def test_relationship_operation_task_execution(ctx, executor):
operation_name = 'aria.interfaces.relationship_lifecycle.postconfigure'
-
- dependency_node = mock.models.get_dependency_node()
- dependency_node_instance = mock.models.get_dependency_node_instance()
- relationship = mock.models.get_relationship(target=dependency_node)
+ relationship = ctx.model.relationship.list()[0]
relationship.source_operations[operation_name] = {
'operation': op_path(my_operation, module_path=__name__)
}
- relationship_instance = mock.models.get_relationship_instance(
- target_instance=dependency_node_instance,
- relationship=relationship)
- dependent_node = mock.models.get_dependent_node()
- dependent_node_instance = mock.models.get_dependent_node_instance(
- relationship_instance=relationship_instance,
- dependent_node=dependency_node)
- ctx.model.node.store(dependency_node)
- ctx.model.node_instance.store(dependency_node_instance)
- ctx.model.relationship.store(relationship)
- ctx.model.relationship_instance.store(relationship_instance)
- ctx.model.node.store(dependent_node)
- ctx.model.node_instance.store(dependent_node_instance)
+ ctx.model.relationship.update(relationship)
+ relationship_instance = ctx.model.relationship_instance.list()[0]
+
+ dependency_node = ctx.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
+ dependency_node_instance = \
+ ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+ dependent_node = ctx.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME)
+ dependent_node_instance = \
+ ctx.model.node_instance.get_by_name(mock.models.DEPENDENT_NODE_INSTANCE_NAME)
inputs = {'putput': True}
@@ -146,11 +140,49 @@ def test_relationship_operation_task_execution(ctx, executor):
assert operation_context.source_node_instance == dependent_node_instance
+def test_invalid_task_operation_id(ctx, executor):
+ """
+    Checks that the correct id is used: the task is created for the node_instance with
+    id == 2 (not the one with id == 1), and the operation verifies that exact id is seen.
+ :param ctx:
+ :param executor:
+ :return:
+ """
+ operation_name = 'aria.interfaces.lifecycle.create'
+ other_node_instance, node_instance = ctx.model.node_instance.list()
+ assert other_node_instance.id == 1
+ assert node_instance.id == 2
+
+ node = node_instance.node
+ node.operations[operation_name] = {
+ 'operation': op_path(get_node_instance_id, module_path=__name__)
+
+ }
+ ctx.model.node.update(node)
+
+ @workflow
+ def basic_workflow(graph, **_):
+ graph.add_tasks(
+ api.task.OperationTask.node_instance(name=operation_name, instance=node_instance)
+ )
+
+ execute(workflow_func=basic_workflow, workflow_context=ctx, executor=executor)
+
+ op_node_instance_id = global_test_holder[op_name(node_instance, operation_name)]
+ assert op_node_instance_id == node_instance.id
+ assert op_node_instance_id != other_node_instance.id
+
+
@operation
def my_operation(ctx, **_):
global_test_holder[ctx.name] = ctx
+@operation
+def get_node_instance_id(ctx, **_):
+ global_test_holder[ctx.name] = ctx.node_instance.id
+
+
@pytest.fixture(autouse=True)
def cleanup():
global_test_holder.clear()
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/orchestrator/context/test_toolbelt.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/context/test_toolbelt.py b/tests/orchestrator/context/test_toolbelt.py
index 547e62b..da46696 100644
--- a/tests/orchestrator/context/test_toolbelt.py
+++ b/tests/orchestrator/context/test_toolbelt.py
@@ -21,7 +21,7 @@ from aria.orchestrator.workflows import api
from aria.orchestrator.workflows.executor import thread
from aria.orchestrator.context.toolbelt import RelationshipToolBelt
-from tests import mock
+from tests import mock, storage
from . import (
op_path,
op_name,
@@ -32,8 +32,10 @@ global_test_holder = {}
@pytest.fixture
-def workflow_context():
- return mock.context.simple()
+def workflow_context(tmpdir):
+ context = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir)))
+ yield context
+ storage.release_sqlite_storage(context.model)
@pytest.fixture
@@ -45,63 +47,39 @@ def executor():
result.close()
-def _create_simple_model_in_storage(workflow_context):
- dependency_node = mock.models.get_dependency_node()
- dependency_node_instance = mock.models.get_dependency_node_instance(
- dependency_node=dependency_node)
- relationship = mock.models.get_relationship(target=dependency_node)
- relationship_instance = mock.models.get_relationship_instance(
- target_instance=dependency_node_instance, relationship=relationship)
- dependent_node = mock.models.get_dependent_node()
- dependent_node_instance = mock.models.get_dependent_node_instance(
- relationship_instance=relationship_instance, dependent_node=dependency_node)
- workflow_context.model.node.store(dependency_node)
- workflow_context.model.node_instance.store(dependency_node_instance)
- workflow_context.model.relationship.store(relationship)
- workflow_context.model.relationship_instance.store(relationship_instance)
- workflow_context.model.node.store(dependent_node)
- workflow_context.model.node_instance.store(dependent_node_instance)
- return dependency_node, dependency_node_instance, \
- dependent_node, dependent_node_instance, \
- relationship, relationship_instance
+def _get_elements(workflow_context):
+ dependency_node = workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
+ dependency_node.host_id = dependency_node.id
+ workflow_context.model.node.update(dependency_node)
+ dependency_node_instance = workflow_context.model.node_instance.get_by_name(
+ mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+ dependency_node_instance.host_id = dependency_node_instance.id
+ workflow_context.model.node_instance.update(dependency_node_instance)
-def test_host_ip(workflow_context, executor):
- operation_name = 'aria.interfaces.lifecycle.create'
- dependency_node, dependency_node_instance, _, _, _, _ = \
- _create_simple_model_in_storage(workflow_context)
- dependency_node.operations[operation_name] = {
- 'operation': op_path(host_ip, module_path=__name__)
-
- }
- workflow_context.model.node.store(dependency_node)
- inputs = {'putput': True}
-
- @workflow
- def basic_workflow(graph, **_):
- graph.add_tasks(
- api.task.OperationTask.node_instance(
- instance=dependency_node_instance,
- name=operation_name,
- inputs=inputs
- )
- )
+ dependent_node = workflow_context.model.node.get_by_name(mock.models.DEPENDENT_NODE_NAME)
+ dependent_node.host_id = dependency_node.id
+ workflow_context.model.node.update(dependent_node)
- execute(workflow_func=basic_workflow, workflow_context=workflow_context, executor=executor)
+ dependent_node_instance = workflow_context.model.node_instance.get_by_name(
+ mock.models.DEPENDENT_NODE_INSTANCE_NAME)
+ dependent_node_instance.host_id = dependent_node_instance.id
+ workflow_context.model.node_instance.update(dependent_node_instance)
- assert global_test_holder.get('host_ip') == \
- dependency_node_instance.runtime_properties.get('ip')
+ relationship = workflow_context.model.relationship.list()[0]
+ relationship_instance = workflow_context.model.relationship_instance.list()[0]
+ return dependency_node, dependency_node_instance, dependent_node, dependent_node_instance, \
+ relationship, relationship_instance
-def test_dependent_node_instances(workflow_context, executor):
+def test_host_ip(workflow_context, executor):
operation_name = 'aria.interfaces.lifecycle.create'
- dependency_node, dependency_node_instance, _, dependent_node_instance, _, _ = \
- _create_simple_model_in_storage(workflow_context)
+ dependency_node, dependency_node_instance, _, _, _, _ = _get_elements(workflow_context)
dependency_node.operations[operation_name] = {
- 'operation': op_path(dependent_nodes, module_path=__name__)
+ 'operation': op_path(host_ip, module_path=__name__)
}
- workflow_context.model.node.store(dependency_node)
+ workflow_context.model.node.put(dependency_node)
inputs = {'putput': True}
@workflow
@@ -116,18 +94,18 @@ def test_dependent_node_instances(workflow_context, executor):
execute(workflow_func=basic_workflow, workflow_context=workflow_context, executor=executor)
- assert list(global_test_holder.get('dependent_node_instances', [])) == \
- list([dependent_node_instance])
+ assert global_test_holder.get('host_ip') == \
+ dependency_node_instance.runtime_properties.get('ip')
def test_relationship_tool_belt(workflow_context, executor):
operation_name = 'aria.interfaces.relationship_lifecycle.postconfigure'
_, _, _, _, relationship, relationship_instance = \
- _create_simple_model_in_storage(workflow_context)
+ _get_elements(workflow_context)
relationship.source_operations[operation_name] = {
'operation': op_path(relationship_operation, module_path=__name__)
}
- workflow_context.model.relationship.store(relationship)
+ workflow_context.model.relationship.put(relationship)
inputs = {'putput': True}
@@ -152,17 +130,13 @@ def test_wrong_model_toolbelt():
with pytest.raises(RuntimeError):
context.toolbelt(None)
+
@operation(toolbelt=True)
def host_ip(toolbelt, **_):
global_test_holder['host_ip'] = toolbelt.host_ip
@operation(toolbelt=True)
-def dependent_nodes(toolbelt, **_):
- global_test_holder['dependent_node_instances'] = list(toolbelt.dependent_node_instances)
-
-
-@operation(toolbelt=True)
def relationship_operation(ctx, toolbelt, **_):
global_test_holder[ctx.name] = toolbelt
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c6c92ae5/tests/orchestrator/context/test_workflow.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/context/test_workflow.py b/tests/orchestrator/context/test_workflow.py
index 258f0c5..496c1ff 100644
--- a/tests/orchestrator/context/test_workflow.py
+++ b/tests/orchestrator/context/test_workflow.py
@@ -19,20 +19,19 @@ import pytest
from aria import application_model_storage
from aria.orchestrator import context
-
+from aria.storage.sql_mapi import SQLAlchemyModelAPI
+from tests import storage as test_storage
from tests.mock import models
-from tests.storage import InMemoryModelDriver
class TestWorkflowContext(object):
def test_execution_creation_on_workflow_context_creation(self, storage):
- self._create_ctx(storage)
- execution = storage.execution.get(models.EXECUTION_ID)
- assert execution.id == models.EXECUTION_ID
- assert execution.deployment_id == models.DEPLOYMENT_ID
- assert execution.workflow_id == models.WORKFLOW_ID
- assert execution.blueprint_id == models.BLUEPRINT_ID
+ ctx = self._create_ctx(storage)
+ execution = storage.execution.get(ctx.execution.id) # pylint: disable=no-member
+ assert execution.deployment == storage.deployment.get_by_name(models.DEPLOYMENT_NAME)
+ assert execution.workflow_name == models.WORKFLOW_NAME
+ assert execution.blueprint == storage.blueprint.get_by_name(models.BLUEPRINT_NAME)
assert execution.status == storage.execution.model_cls.PENDING
assert execution.parameters == {}
assert execution.created_at <= datetime.utcnow()
@@ -43,13 +42,17 @@ class TestWorkflowContext(object):
@staticmethod
def _create_ctx(storage):
+        """
+        Create a WorkflowContext backed by the given model storage.
+        :param storage: model storage pre-populated with a blueprint and deployment
+        :return WorkflowContext:
+        """
return context.workflow.WorkflowContext(
name='simple_context',
model_storage=storage,
resource_storage=None,
- deployment_id=models.DEPLOYMENT_ID,
- workflow_id=models.WORKFLOW_ID,
- execution_id=models.EXECUTION_ID,
+ deployment_id=storage.deployment.get_by_name(models.DEPLOYMENT_NAME).id,
+ workflow_name=models.WORKFLOW_NAME,
task_max_attempts=models.TASK_MAX_ATTEMPTS,
task_retry_interval=models.TASK_RETRY_INTERVAL
)
@@ -57,8 +60,10 @@ class TestWorkflowContext(object):
@pytest.fixture(scope='function')
def storage():
- result = application_model_storage(InMemoryModelDriver())
- result.setup()
- result.blueprint.store(models.get_blueprint())
- result.deployment.store(models.get_deployment())
- return result
+ api_kwargs = test_storage.get_sqlite_api_kwargs()
+ workflow_storage = application_model_storage(SQLAlchemyModelAPI, api_kwargs=api_kwargs)
+ workflow_storage.blueprint.put(models.get_blueprint())
+ blueprint = workflow_storage.blueprint.get_by_name(models.BLUEPRINT_NAME)
+ workflow_storage.deployment.put(models.get_deployment(blueprint))
+ yield workflow_storage
+ test_storage.release_sqlite_storage(workflow_storage)