You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/06/21 14:38:01 UTC
incubator-airflow git commit: [AIRFLOW-1275] Put 'airflow pool' into API
Repository: incubator-airflow
Updated Branches:
refs/heads/master a45e2d188 -> 9958aa9d5
[AIRFLOW-1275] Put 'airflow pool' into API
Closes #2346 from skudriashev/airflow-1275
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/9958aa9d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/9958aa9d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/9958aa9d
Branch: refs/heads/master
Commit: 9958aa9d5326b75cf7082c0bc36c13b063f1924f
Parents: a45e2d1
Author: Stanislav Kudriashev <st...@gmail.com>
Authored: Wed Jun 21 16:36:45 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Wed Jun 21 16:36:51 2017 +0200
----------------------------------------------------------------------
airflow/api/client/api_client.py | 36 +-
airflow/api/client/json_client.py | 60 +++-
airflow/api/client/local_client.py | 20 +-
airflow/api/common/experimental/pool.py | 85 +++++
airflow/bin/cli.py | 52 ++-
airflow/models.py | 8 +
airflow/www/api/experimental/endpoints.py | 68 +++-
tests/api/__init__.py | 6 -
tests/api/client/local_client.py | 107 ------
tests/api/client/test_local_client.py | 144 ++++++++
tests/api/common/experimental/__init__.py | 13 +
tests/api/common/experimental/mark_tasks.py | 396 ++++++++++++++++++++++
tests/api/common/experimental/test_pool.py | 132 ++++++++
tests/api/common/mark_tasks.py | 396 ----------------------
tests/core.py | 59 +++-
tests/www/api/experimental/test_endpoints.py | 153 ++++++++-
16 files changed, 1155 insertions(+), 580 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/client/api_client.py
----------------------------------------------------------------------
diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py
index 6a77538..f24d809 100644
--- a/airflow/api/client/api_client.py
+++ b/airflow/api/client/api_client.py
@@ -14,17 +14,47 @@
#
-class Client:
+class Client(object):
+ """Base API client for all API clients."""
+
def __init__(self, api_base_url, auth):
self._api_base_url = api_base_url
self._auth = auth
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
- """
- Creates a dag run for the specified dag
+ """Create a dag run for the specified dag.
+
:param dag_id:
:param run_id:
:param conf:
+ :param execution_date:
:return:
"""
raise NotImplementedError()
+
+ def get_pool(self, name):
+ """Get pool.
+
+ :param name: pool name
+ """
+ raise NotImplementedError()
+
+ def get_pools(self):
+ """Get all pools."""
+ raise NotImplementedError()
+
+ def create_pool(self, name, slots, description):
+ """Create a pool.
+
+ :param name: pool name
+ :param slots: pool slots amount
+ :param description: pool description
+ """
+ raise NotImplementedError()
+
+ def delete_pool(self, name):
+ """Delete pool.
+
+ :param name: pool name
+ """
+ raise NotImplementedError()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/client/json_client.py
----------------------------------------------------------------------
diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py
index d74fc63..37e24d3 100644
--- a/airflow/api/client/json_client.py
+++ b/airflow/api/client/json_client.py
@@ -11,30 +11,70 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
+
from future.moves.urllib.parse import urljoin
+import requests
from airflow.api.client import api_client
-import requests
-
class Client(api_client.Client):
+ """Json API client implementation."""
+
+ def _request(self, url, method='GET', json=None):
+ params = {
+ 'url': url,
+ 'auth': self._auth,
+ }
+ if json is not None:
+ params['json'] = json
+
+ resp = getattr(requests, method.lower())(**params)
+ if not resp.ok:
+ try:
+ data = resp.json()
+ except Exception:
+ data = {}
+ raise IOError(data.get('error', 'Server error'))
+
+ return resp.json()
+
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
endpoint = '/api/experimental/dags/{}/dag_runs'.format(dag_id)
url = urljoin(self._api_base_url, endpoint)
-
- resp = requests.post(url,
- auth=self._auth,
+ data = self._request(url, method='POST',
json={
"run_id": run_id,
"conf": conf,
"execution_date": execution_date,
})
+ return data['message']
- if not resp.ok:
- raise IOError()
+ def get_pool(self, name):
+ endpoint = '/api/experimental/pools/{}'.format(name)
+ url = urljoin(self._api_base_url, endpoint)
+ pool = self._request(url)
+ return pool['pool'], pool['slots'], pool['description']
- data = resp.json()
+ def get_pools(self):
+ endpoint = '/api/experimental/pools'
+ url = urljoin(self._api_base_url, endpoint)
+ pools = self._request(url)
+ return [(p['pool'], p['slots'], p['description']) for p in pools]
- return data['message']
+ def create_pool(self, name, slots, description):
+ endpoint = '/api/experimental/pools'
+ url = urljoin(self._api_base_url, endpoint)
+ pool = self._request(url, method='POST',
+ json={
+ 'name': name,
+ 'slots': slots,
+ 'description': description,
+ })
+ return pool['pool'], pool['slots'], pool['description']
+
+ def delete_pool(self, name):
+ endpoint = '/api/experimental/pools/{}'.format(name)
+ url = urljoin(self._api_base_url, endpoint)
+ pool = self._request(url, method='DELETE')
+ return pool['pool'], pool['slots'], pool['description']
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/client/local_client.py
----------------------------------------------------------------------
diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index 05f27f6..5bc7f76 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -11,15 +11,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
+
from airflow.api.client import api_client
+from airflow.api.common.experimental import pool
from airflow.api.common.experimental import trigger_dag
class Client(api_client.Client):
+ """Local API client implementation."""
+
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
dr = trigger_dag.trigger_dag(dag_id=dag_id,
run_id=run_id,
conf=conf,
execution_date=execution_date)
return "Created {}".format(dr)
+
+ def get_pool(self, name):
+ p = pool.get_pool(name=name)
+ return p.pool, p.slots, p.description
+
+ def get_pools(self):
+ return [(p.pool, p.slots, p.description) for p in pool.get_pools()]
+
+ def create_pool(self, name, slots, description):
+ p = pool.create_pool(name=name, slots=slots, description=description)
+ return p.pool, p.slots, p.description
+
+ def delete_pool(self, name):
+ p = pool.delete_pool(name=name)
+ return p.pool, p.slots, p.description
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/api/common/experimental/pool.py
----------------------------------------------------------------------
diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py
new file mode 100644
index 0000000..6e963a2
--- /dev/null
+++ b/airflow/api/common/experimental/pool.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.exceptions import AirflowException
+from airflow.models import Pool
+from airflow.utils.db import provide_session
+
+
+class PoolBadRequest(AirflowException):
+ status = 400
+
+
+class PoolNotFound(AirflowException):
+ status = 404
+
+
+@provide_session
+def get_pool(name, session=None):
+ """Get pool by a given name."""
+ if not (name and name.strip()):
+ raise PoolBadRequest("Pool name shouldn't be empty")
+
+ pool = session.query(Pool).filter_by(pool=name).first()
+ if pool is None:
+ raise PoolNotFound("Pool '%s' doesn't exist" % name)
+
+ return pool
+
+
+@provide_session
+def get_pools(session=None):
+ """Get all pools."""
+ return session.query(Pool).all()
+
+
+@provide_session
+def create_pool(name, slots, description, session=None):
+ """Create a pool with a given parameters."""
+ if not (name and name.strip()):
+ raise PoolBadRequest("Pool name shouldn't be empty")
+
+ try:
+ slots = int(slots)
+ except ValueError:
+ raise PoolBadRequest("Bad value for `slots`: %s" % slots)
+
+ session.expire_on_commit = False
+ pool = session.query(Pool).filter_by(pool=name).first()
+ if pool is None:
+ pool = Pool(pool=name, slots=slots, description=description)
+ session.add(pool)
+ else:
+ pool.slots = slots
+ pool.description = description
+
+ session.commit()
+
+ return pool
+
+
+@provide_session
+def delete_pool(name, session=None):
+ """Delete pool by a given name."""
+ if not (name and name.strip()):
+ raise PoolBadRequest("Pool name shouldn't be empty")
+
+ pool = session.query(Pool).filter_by(pool=name).first()
+ if pool is None:
+ raise PoolNotFound("Pool '%s' doesn't exist" % name)
+
+ session.delete(pool)
+ session.commit()
+
+ return pool
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/bin/cli.py
----------------------------------------------------------------------
diff --git a/airflow/bin/cli.py b/airflow/bin/cli.py
index 41f979f..4b3a0ed 100755
--- a/airflow/bin/cli.py
+++ b/airflow/bin/cli.py
@@ -49,7 +49,7 @@ from airflow.exceptions import AirflowException
from airflow.executors import GetDefaultExecutor
from airflow.models import (DagModel, DagBag, TaskInstance,
DagPickle, DagRun, Variable, DagStat,
- Pool, Connection)
+ Connection)
from airflow.ti_deps.dep_context import (DepContext, SCHEDULER_DEPS)
from airflow.utils import db as db_utils
from airflow.utils import logging as logging_utils
@@ -187,40 +187,28 @@ def trigger_dag(args):
def pool(args):
- session = settings.Session()
- if args.get or (args.set and args.set[0]) or args.delete:
- name = args.get or args.delete or args.set[0]
- pool = (
- session.query(Pool)
- .filter(Pool.pool == name)
- .first())
- if pool and args.get:
- print("{} ".format(pool))
- return
- elif not pool and (args.get or args.delete):
- print("No pool named {} found".format(name))
- elif not pool and args.set:
- pool = Pool(
- pool=name,
- slots=args.set[1],
- description=args.set[2])
- session.add(pool)
- session.commit()
- print("{} ".format(pool))
- elif pool and args.set:
- pool.slots = args.set[1]
- pool.description = args.set[2]
- session.commit()
- print("{} ".format(pool))
- return
- elif pool and args.delete:
- session.query(Pool).filter_by(pool=args.delete).delete()
- session.commit()
- print("Pool {} deleted".format(name))
+ def _tabulate(pools):
+ return "\n%s" % tabulate(pools, ['Pool', 'Slots', 'Description'],
+ tablefmt="fancy_grid")
+ try:
+ if args.get is not None:
+ pools = [api_client.get_pool(name=args.get)]
+ elif args.set:
+ pools = [api_client.create_pool(name=args.set[0],
+ slots=args.set[1],
+ description=args.set[2])]
+ elif args.delete:
+ pools = [api_client.delete_pool(name=args.delete)]
+ else:
+ pools = api_client.get_pools()
+ except (AirflowException, IOError) as err:
+ logging.error(err)
+ else:
+ logging.info(_tabulate(pools=pools))
-def variables(args):
+def variables(args):
if args.get:
try:
var = Variable.get(args.get,
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 2c433ad..0002572 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -4395,6 +4395,14 @@ class Pool(Base):
def __repr__(self):
return self.pool
+ def to_json(self):
+ return {
+ 'id': self.id,
+ 'pool': self.pool,
+ 'slots': self.slots,
+ 'description': self.description,
+ }
+
@provide_session
def used_slots(self, session):
"""
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/airflow/www/api/experimental/endpoints.py
----------------------------------------------------------------------
diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py
index be92735..a8d7f5c 100644
--- a/airflow/www/api/experimental/endpoints.py
+++ b/airflow/www/api/experimental/endpoints.py
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
import airflow.api
+from airflow.api.common.experimental import pool as pool_api
from airflow.api.common.experimental import trigger_dag as trigger
from airflow.api.common.experimental.get_task import get_task
from airflow.api.common.experimental.get_task_instance import get_task_instance
@@ -96,7 +98,6 @@ def test():
@requires_authentication
def task_info(dag_id, task_id):
"""Returns a JSON with a task's public instance variables. """
-
try:
info = get_task(dag_id, task_id)
except AirflowException as err:
@@ -169,4 +170,67 @@ def latest_dag_runs():
'dag_run_url': url_for('airflow.graph', dag_id=dagrun.dag_id,
execution_date=dagrun.execution_date)
})
- return jsonify(items=payload) # old flask versions dont support jsonifying arrays
+ return jsonify(items=payload) # old flask versions dont support jsonifying arrays
+
+
+@api_experimental.route('/pools/<string:name>', methods=['GET'])
+@requires_authentication
+def get_pool(name):
+ """Get pool by a given name."""
+ try:
+ pool = pool_api.get_pool(name=name)
+ except AirflowException as e:
+ _log.error(e)
+ response = jsonify(error="{}".format(e))
+ response.status_code = getattr(e, 'status', 500)
+ return response
+ else:
+ return jsonify(pool.to_json())
+
+
+@api_experimental.route('/pools', methods=['GET'])
+@requires_authentication
+def get_pools():
+ """Get all pools."""
+ try:
+ pools = pool_api.get_pools()
+ except AirflowException as e:
+ _log.error(e)
+ response = jsonify(error="{}".format(e))
+ response.status_code = getattr(e, 'status', 500)
+ return response
+ else:
+ return jsonify([p.to_json() for p in pools])
+
+
+@csrf.exempt
+@api_experimental.route('/pools', methods=['POST'])
+@requires_authentication
+def create_pool():
+ """Create a pool."""
+ params = request.get_json(force=True)
+ try:
+ pool = pool_api.create_pool(**params)
+ except AirflowException as e:
+ _log.error(e)
+ response = jsonify(error="{}".format(e))
+ response.status_code = getattr(e, 'status', 500)
+ return response
+ else:
+ return jsonify(pool.to_json())
+
+
+@csrf.exempt
+@api_experimental.route('/pools/<string:name>', methods=['DELETE'])
+@requires_authentication
+def delete_pool(name):
+ """Delete pool."""
+ try:
+ pool = pool_api.delete_pool(name=name)
+ except AirflowException as e:
+ _log.error(e)
+ response = jsonify(error="{}".format(e))
+ response.status_code = getattr(e, 'status', 500)
+ return response
+ else:
+ return jsonify(pool.to_json())
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/__init__.py
----------------------------------------------------------------------
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
index 37d59f0..9d7677a 100644
--- a/tests/api/__init__.py
+++ b/tests/api/__init__.py
@@ -11,9 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from __future__ import absolute_import
-
-from .client import *
-from .common import *
-
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/client/local_client.py
----------------------------------------------------------------------
diff --git a/tests/api/client/local_client.py b/tests/api/client/local_client.py
deleted file mode 100644
index a36b71f..0000000
--- a/tests/api/client/local_client.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import unittest
-import datetime
-
-from mock import patch
-
-from airflow import AirflowException
-from airflow import models
-
-from airflow.api.client.local_client import Client
-from airflow.utils.state import State
-
-EXECDATE = datetime.datetime.now()
-EXECDATE_NOFRACTIONS = EXECDATE.replace(microsecond=0)
-EXECDATE_ISO = EXECDATE_NOFRACTIONS.isoformat()
-
-real_datetime_class = datetime.datetime
-
-
-def mock_datetime_now(target, dt):
- class DatetimeSubclassMeta(type):
- @classmethod
- def __instancecheck__(mcs, obj):
- return isinstance(obj, real_datetime_class)
-
- class BaseMockedDatetime(real_datetime_class):
- @classmethod
- def now(cls, tz=None):
- return target.replace(tzinfo=tz)
-
- @classmethod
- def utcnow(cls):
- return target
-
- # Python2 & Python3 compatible metaclass
- MockedDatetime = DatetimeSubclassMeta('datetime', (BaseMockedDatetime,), {})
-
- return patch.object(dt, 'datetime', MockedDatetime)
-
-
-class TestLocalClient(unittest.TestCase):
- def setUp(self):
- self.client = Client(api_base_url=None, auth=None)
-
- @patch.object(models.DAG, 'create_dagrun')
- def test_trigger_dag(self, mock):
- client = self.client
-
- # non existent
- with self.assertRaises(AirflowException):
- client.trigger_dag(dag_id="blablabla")
-
- import airflow.api.common.experimental.trigger_dag
- with mock_datetime_now(EXECDATE, airflow.api.common.experimental.trigger_dag.datetime):
- # no execution date, execution date should be set automatically
- client.trigger_dag(dag_id="test_start_date_scheduling")
- mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=None,
- external_trigger=True)
- mock.reset_mock()
-
- # execution date with microseconds cutoff
- client.trigger_dag(dag_id="test_start_date_scheduling", execution_date=EXECDATE)
- mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=None,
- external_trigger=True)
- mock.reset_mock()
-
- # run id
- run_id = "my_run_id"
- client.trigger_dag(dag_id="test_start_date_scheduling", run_id=run_id)
- mock.assert_called_once_with(run_id=run_id,
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=None,
- external_trigger=True)
- mock.reset_mock()
-
- # test conf
- conf = '{"name": "John"}'
- client.trigger_dag(dag_id="test_start_date_scheduling", conf=conf)
- mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
- execution_date=EXECDATE_NOFRACTIONS,
- state=State.RUNNING,
- conf=json.loads(conf),
- external_trigger=True)
- mock.reset_mock()
-
- # this is a unit test only, cannot verify existing dag run
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/client/test_local_client.py
----------------------------------------------------------------------
diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py
new file mode 100644
index 0000000..7a759fe
--- /dev/null
+++ b/tests/api/client/test_local_client.py
@@ -0,0 +1,144 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+import json
+import unittest
+
+from mock import patch
+
+from airflow import AirflowException
+from airflow.api.client.local_client import Client
+from airflow import models
+from airflow import settings
+from airflow.utils.state import State
+
+EXECDATE = datetime.datetime.now()
+EXECDATE_NOFRACTIONS = EXECDATE.replace(microsecond=0)
+EXECDATE_ISO = EXECDATE_NOFRACTIONS.isoformat()
+
+real_datetime_class = datetime.datetime
+
+
+def mock_datetime_now(target, dt):
+ class DatetimeSubclassMeta(type):
+ @classmethod
+ def __instancecheck__(mcs, obj):
+ return isinstance(obj, real_datetime_class)
+
+ class BaseMockedDatetime(real_datetime_class):
+ @classmethod
+ def now(cls, tz=None):
+ return target.replace(tzinfo=tz)
+
+ @classmethod
+ def utcnow(cls):
+ return target
+
+ # Python2 & Python3 compatible metaclass
+ MockedDatetime = DatetimeSubclassMeta('datetime', (BaseMockedDatetime,), {})
+
+ return patch.object(dt, 'datetime', MockedDatetime)
+
+
+class TestLocalClient(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestLocalClient, cls).setUpClass()
+ session = settings.Session()
+ session.query(models.Pool).delete()
+ session.commit()
+ session.close()
+
+ def setUp(self):
+ super(TestLocalClient, self).setUp()
+ self.client = Client(api_base_url=None, auth=None)
+ self.session = settings.Session()
+
+ def tearDown(self):
+ self.session.query(models.Pool).delete()
+ self.session.commit()
+ self.session.close()
+ super(TestLocalClient, self).tearDown()
+
+ @patch.object(models.DAG, 'create_dagrun')
+ def test_trigger_dag(self, mock):
+ client = self.client
+
+ # non existent
+ with self.assertRaises(AirflowException):
+ client.trigger_dag(dag_id="blablabla")
+
+ import airflow.api.common.experimental.trigger_dag
+ with mock_datetime_now(EXECDATE, airflow.api.common.experimental.trigger_dag.datetime):
+ # no execution date, execution date should be set automatically
+ client.trigger_dag(dag_id="test_start_date_scheduling")
+ mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=None,
+ external_trigger=True)
+ mock.reset_mock()
+
+ # execution date with microseconds cutoff
+ client.trigger_dag(dag_id="test_start_date_scheduling", execution_date=EXECDATE)
+ mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=None,
+ external_trigger=True)
+ mock.reset_mock()
+
+ # run id
+ run_id = "my_run_id"
+ client.trigger_dag(dag_id="test_start_date_scheduling", run_id=run_id)
+ mock.assert_called_once_with(run_id=run_id,
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=None,
+ external_trigger=True)
+ mock.reset_mock()
+
+ # test conf
+ conf = '{"name": "John"}'
+ client.trigger_dag(dag_id="test_start_date_scheduling", conf=conf)
+ mock.assert_called_once_with(run_id="manual__{0}".format(EXECDATE_ISO),
+ execution_date=EXECDATE_NOFRACTIONS,
+ state=State.RUNNING,
+ conf=json.loads(conf),
+ external_trigger=True)
+ mock.reset_mock()
+
+ def test_get_pool(self):
+ self.client.create_pool(name='foo', slots=1, description='')
+ pool = self.client.get_pool(name='foo')
+ self.assertEqual(pool, ('foo', 1, ''))
+
+ def test_get_pools(self):
+ self.client.create_pool(name='foo1', slots=1, description='')
+ self.client.create_pool(name='foo2', slots=2, description='')
+ pools = sorted(self.client.get_pools(), key=lambda p: p[0])
+ self.assertEqual(pools, [('foo1', 1, ''), ('foo2', 2, '')])
+
+ def test_create_pool(self):
+ pool = self.client.create_pool(name='foo', slots=1, description='')
+ self.assertEqual(pool, ('foo', 1, ''))
+ self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+ def test_delete_pool(self):
+ self.client.create_pool(name='foo', slots=1, description='')
+ self.assertEqual(self.session.query(models.Pool).count(), 1)
+ self.client.delete_pool(name='foo')
+ self.assertEqual(self.session.query(models.Pool).count(), 0)
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/experimental/__init__.py
----------------------------------------------------------------------
diff --git a/tests/api/common/experimental/__init__.py b/tests/api/common/experimental/__init__.py
new file mode 100644
index 0000000..9d7677a
--- /dev/null
+++ b/tests/api/common/experimental/__init__.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/experimental/mark_tasks.py
----------------------------------------------------------------------
diff --git a/tests/api/common/experimental/mark_tasks.py b/tests/api/common/experimental/mark_tasks.py
new file mode 100644
index 0000000..e4395ae
--- /dev/null
+++ b/tests/api/common/experimental/mark_tasks.py
@@ -0,0 +1,396 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow import models
+from airflow.api.common.experimental.mark_tasks import (
+ set_state, _create_dagruns, set_dag_run_state)
+from airflow.settings import Session
+from airflow.utils.dates import days_ago
+from airflow.utils.state import State
+from datetime import datetime, timedelta
+
+DEV_NULL = "/dev/null"
+
+
+class TestMarkTasks(unittest.TestCase):
+
+ def setUp(self):
+ self.dagbag = models.DagBag(include_examples=True)
+ self.dag1 = self.dagbag.dags['test_example_bash_operator']
+ self.dag2 = self.dagbag.dags['example_subdag_operator']
+
+ self.execution_dates = [days_ago(2), days_ago(1)]
+
+ drs = _create_dagruns(self.dag1, self.execution_dates,
+ state=State.RUNNING,
+ run_id_template="scheduled__{}")
+ for dr in drs:
+ dr.dag = self.dag1
+ dr.verify_integrity()
+
+ drs = _create_dagruns(self.dag2,
+ [self.dag2.default_args['start_date']],
+ state=State.RUNNING,
+ run_id_template="scheduled__{}")
+
+ for dr in drs:
+ dr.dag = self.dag2
+ dr.verify_integrity()
+
+ self.session = Session()
+
+ def tearDown(self):
+ self.dag1.clear()
+ self.dag2.clear()
+
+ # just to make sure we are fully cleaned up
+ self.session.query(models.DagRun).delete()
+ self.session.query(models.TaskInstance).delete()
+ self.session.commit()
+ self.session.close()
+
+ def snapshot_state(self, dag, execution_dates):
+ TI = models.TaskInstance
+ tis = self.session.query(TI).filter(
+ TI.dag_id==dag.dag_id,
+ TI.execution_date.in_(execution_dates)
+ ).all()
+
+ self.session.expunge_all()
+
+ return tis
+
+ def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
+ TI = models.TaskInstance
+
+ tis = self.session.query(TI).filter(
+ TI.dag_id==dag.dag_id,
+ TI.execution_date.in_(execution_dates)
+ ).all()
+
+ self.assertTrue(len(tis) > 0)
+
+ for ti in tis:
+ if ti.task_id in task_ids and ti.execution_date in execution_dates:
+ self.assertEqual(ti.state, state)
+ else:
+ for old_ti in old_tis:
+ if (old_ti.task_id == ti.task_id
+ and old_ti.execution_date == ti.execution_date):
+ self.assertEqual(ti.state, old_ti.state)
+
+ def test_mark_tasks_now(self):
+ # set one task to success but do not commit
+ snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+ task = self.dag1.get_task("runme_1")
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=False, future=False,
+ past=False, state=State.SUCCESS, commit=False)
+ self.assertEqual(len(altered), 1)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+ None, snapshot)
+
+ # set one and only one task to success
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=False, future=False,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 1)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+ State.SUCCESS, snapshot)
+
+ # set no tasks
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=False, future=False,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 0)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+ State.SUCCESS, snapshot)
+
+ # set task to other than success
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=False, future=False,
+ past=False, state=State.FAILED, commit=True)
+ self.assertEqual(len(altered), 1)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+ State.FAILED, snapshot)
+
+ # dont alter other tasks
+ snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+ task = self.dag1.get_task("runme_0")
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=False, future=False,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 1)
+ self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+ State.SUCCESS, snapshot)
+
+ def test_mark_downstream(self):
+ # test downstream
+ snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+ task = self.dag1.get_task("runme_1")
+ relatives = task.get_flat_relatives(upstream=False)
+ task_ids = [t.task_id for t in relatives]
+ task_ids.append(task.task_id)
+
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=True, future=False,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 3)
+ self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
+ State.SUCCESS, snapshot)
+
+ def test_mark_upstream(self):
+ # test upstream
+ snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+ task = self.dag1.get_task("run_after_loop")
+ relatives = task.get_flat_relatives(upstream=True)
+ task_ids = [t.task_id for t in relatives]
+ task_ids.append(task.task_id)
+
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=True, downstream=False, future=False,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 4)
+ self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
+ State.SUCCESS, snapshot)
+
+ def test_mark_tasks_future(self):
+ # set one task to success towards end of scheduled dag runs
+ snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+ task = self.dag1.get_task("runme_1")
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=False, future=True,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 2)
+ self.verify_state(self.dag1, [task.task_id], self.execution_dates,
+ State.SUCCESS, snapshot)
+
+ def test_mark_tasks_past(self):
+ # set one task to success towards end of scheduled dag runs
+ snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+ task = self.dag1.get_task("runme_1")
+ altered = set_state(task=task, execution_date=self.execution_dates[1],
+ upstream=False, downstream=False, future=False,
+ past=True, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 2)
+ self.verify_state(self.dag1, [task.task_id], self.execution_dates,
+ State.SUCCESS, snapshot)
+
+ def test_mark_tasks_subdag(self):
+ # set one task to success towards end of scheduled dag runs
+ task = self.dag2.get_task("section-1")
+ relatives = task.get_flat_relatives(upstream=False)
+ task_ids = [t.task_id for t in relatives]
+ task_ids.append(task.task_id)
+
+ altered = set_state(task=task, execution_date=self.execution_dates[0],
+ upstream=False, downstream=True, future=False,
+ past=False, state=State.SUCCESS, commit=True)
+ self.assertEqual(len(altered), 14)
+
+ # cannot use snapshot here as that will require drilling down the
+ # the sub dag tree essentially recreating the same code as in the
+ # tested logic.
+ self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
+ State.SUCCESS, [])
+
+
+class TestMarkDAGRun(unittest.TestCase):
+ def setUp(self):
+ self.dagbag = models.DagBag(include_examples=True)
+ self.dag1 = self.dagbag.dags['test_example_bash_operator']
+ self.dag2 = self.dagbag.dags['example_subdag_operator']
+
+ self.execution_dates = [days_ago(3), days_ago(2), days_ago(1)]
+
+ self.session = Session()
+
+ def verify_dag_run_states(self, dag, date, state=State.SUCCESS):
+ drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
+ dr = drs[0]
+ self.assertEqual(dr.get_state(), state)
+ tis = dr.get_task_instances(session=self.session)
+ for ti in tis:
+ self.assertEqual(ti.state, state)
+
+ def test_set_running_dag_run_state(self):
+ date = self.execution_dates[0]
+ dr = self.dag1.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.RUNNING,
+ execution_date=date,
+ session=self.session
+ )
+ for ti in dr.get_task_instances(session=self.session):
+ ti.set_state(State.RUNNING, self.session)
+
+ altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
+
+ # All of the task should be altered
+ self.assertEqual(len(altered), len(self.dag1.tasks))
+ self.verify_dag_run_states(self.dag1, date)
+
+ def test_set_success_dag_run_state(self):
+ date = self.execution_dates[0]
+
+ dr = self.dag1.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.SUCCESS,
+ execution_date=date,
+ session=self.session
+ )
+ for ti in dr.get_task_instances(session=self.session):
+ ti.set_state(State.SUCCESS, self.session)
+
+ altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
+
+ # None of the task should be altered
+ self.assertEqual(len(altered), 0)
+ self.verify_dag_run_states(self.dag1, date)
+
+ def test_set_failed_dag_run_state(self):
+ date = self.execution_dates[0]
+ dr = self.dag1.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.FAILED,
+ execution_date=date,
+ session=self.session
+ )
+ dr.get_task_instance('runme_0').set_state(State.FAILED, self.session)
+
+ altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
+
+ # All of the task should be altered
+ self.assertEqual(len(altered), len(self.dag1.tasks))
+ self.verify_dag_run_states(self.dag1, date)
+
+ def test_set_mixed_dag_run_state(self):
+ """
+ This test checks function set_dag_run_state with mixed task instance
+ state.
+ """
+ date = self.execution_dates[0]
+ dr = self.dag1.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.FAILED,
+ execution_date=date,
+ session=self.session
+ )
+ # success task
+ dr.get_task_instance('runme_0').set_state(State.SUCCESS, self.session)
+ # skipped task
+ dr.get_task_instance('runme_1').set_state(State.SKIPPED, self.session)
+ # retry task
+ dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY, self.session)
+ # queued task
+ dr.get_task_instance('also_run_this').set_state(State.QUEUED, self.session)
+ # running task
+ dr.get_task_instance('run_after_loop').set_state(State.RUNNING, self.session)
+ # failed task
+ dr.get_task_instance('run_this_last').set_state(State.FAILED, self.session)
+
+ altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
+
+ self.assertEqual(len(altered), len(self.dag1.tasks) - 1) # only 1 task succeeded
+ self.verify_dag_run_states(self.dag1, date)
+
+ def test_set_state_without_commit(self):
+ date = self.execution_dates[0]
+
+ # Running dag run and task instances
+ dr = self.dag1.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.RUNNING,
+ execution_date=date,
+ session=self.session
+ )
+ for ti in dr.get_task_instances(session=self.session):
+ ti.set_state(State.RUNNING, self.session)
+
+ altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=False)
+
+ # All of the task should be altered
+ self.assertEqual(len(altered), len(self.dag1.tasks))
+
+ # Both dag run and task instances' states should remain the same
+ self.verify_dag_run_states(self.dag1, date, State.RUNNING)
+
+ def test_set_state_with_multiple_dagruns(self):
+ dr1 = self.dag2.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.FAILED,
+ execution_date=self.execution_dates[0],
+ session=self.session
+ )
+ dr2 = self.dag2.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.FAILED,
+ execution_date=self.execution_dates[1],
+ session=self.session
+ )
+ dr3 = self.dag2.create_dagrun(
+ run_id='manual__' + datetime.now().isoformat(),
+ state=State.RUNNING,
+ execution_date=self.execution_dates[2],
+ session=self.session
+ )
+
+ altered = set_dag_run_state(self.dag2, self.execution_dates[1],
+ state=State.SUCCESS, commit=True)
+
+ # Recursively count number of tasks in the dag
+ def count_dag_tasks(dag):
+ count = len(dag.tasks)
+ subdag_counts = [count_dag_tasks(subdag) for subdag in dag.subdags]
+ count += sum(subdag_counts)
+ return count
+
+ self.assertEqual(len(altered), count_dag_tasks(self.dag2))
+ self.verify_dag_run_states(self.dag2, self.execution_dates[1])
+
+ # Make sure other dag status are not changed
+ dr1 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[0])
+ dr1 = dr1[0]
+ self.assertEqual(dr1.get_state(), State.FAILED)
+ dr3 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[2])
+ dr3 = dr3[0]
+ self.assertEqual(dr3.get_state(), State.RUNNING)
+
+ def test_set_dag_run_state_edge_cases(self):
+ # Dag does not exist
+ altered = set_dag_run_state(None, self.execution_dates[0])
+ self.assertEqual(len(altered), 0)
+
+ # Invalid execution date
+ altered = set_dag_run_state(self.dag1, None)
+ self.assertEqual(len(altered), 0)
+ self.assertRaises(AssertionError, set_dag_run_state, self.dag1, timedelta(microseconds=-1))
+
+ # DagRun does not exist
+ # This will throw AssertionError since dag.latest_execution_date does not exist
+ self.assertRaises(AssertionError, set_dag_run_state, self.dag1, self.execution_dates[0])
+
+ def tearDown(self):
+ self.dag1.clear()
+ self.dag2.clear()
+
+ self.session.query(models.DagRun).delete()
+ self.session.query(models.TaskInstance).delete()
+ self.session.query(models.DagStat).delete()
+ self.session.commit()
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/experimental/test_pool.py
----------------------------------------------------------------------
diff --git a/tests/api/common/experimental/test_pool.py b/tests/api/common/experimental/test_pool.py
new file mode 100644
index 0000000..98969b8
--- /dev/null
+++ b/tests/api/common/experimental/test_pool.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow.api.common.experimental import pool as pool_api
+from airflow import models
+from airflow import settings
+
+
+class TestPool(unittest.TestCase):
+
+ def setUp(self):
+ super(TestPool, self).setUp()
+ self.session = settings.Session()
+ self.pools = []
+ for i in range(2):
+ name = 'experimental_%s' % (i + 1)
+ pool = models.Pool(
+ pool=name,
+ slots=i,
+ description=name,
+ )
+ self.session.add(pool)
+ self.pools.append(pool)
+ self.session.commit()
+
+ def tearDown(self):
+ self.session.query(models.Pool).delete()
+ self.session.commit()
+ self.session.close()
+ super(TestPool, self).tearDown()
+
+ def test_get_pool(self):
+ pool = pool_api.get_pool(name=self.pools[0].pool, session=self.session)
+ self.assertEqual(pool.pool, self.pools[0].pool)
+
+ def test_get_pool_non_existing(self):
+ self.assertRaisesRegexp(pool_api.PoolNotFound,
+ "^Pool 'test' doesn't exist$",
+ pool_api.get_pool,
+ name='test',
+ session=self.session)
+
+ def test_get_pool_bad_name(self):
+ for name in ('', ' '):
+ self.assertRaisesRegexp(pool_api.PoolBadRequest,
+ "^Pool name shouldn't be empty$",
+ pool_api.get_pool,
+ name=name,
+ session=self.session)
+
+ def test_get_pools(self):
+ pools = sorted(pool_api.get_pools(session=self.session),
+ key=lambda p: p.pool)
+ self.assertEqual(pools[0].pool, self.pools[0].pool)
+ self.assertEqual(pools[1].pool, self.pools[1].pool)
+
+ def test_create_pool(self):
+ pool = pool_api.create_pool(name='foo',
+ slots=5,
+ description='',
+ session=self.session)
+ self.assertEqual(pool.pool, 'foo')
+ self.assertEqual(pool.slots, 5)
+ self.assertEqual(pool.description, '')
+ self.assertEqual(self.session.query(models.Pool).count(), 3)
+
+ def test_create_pool_existing(self):
+ pool = pool_api.create_pool(name=self.pools[0].pool,
+ slots=5,
+ description='',
+ session=self.session)
+ self.assertEqual(pool.pool, self.pools[0].pool)
+ self.assertEqual(pool.slots, 5)
+ self.assertEqual(pool.description, '')
+ self.assertEqual(self.session.query(models.Pool).count(), 2)
+
+ def test_create_pool_bad_name(self):
+ for name in ('', ' '):
+ self.assertRaisesRegexp(pool_api.PoolBadRequest,
+ "^Pool name shouldn't be empty$",
+ pool_api.create_pool,
+ name=name,
+ slots=5,
+ description='',
+ session=self.session)
+
+ def test_create_pool_bad_slots(self):
+ self.assertRaisesRegexp(pool_api.PoolBadRequest,
+ "^Bad value for `slots`: foo$",
+ pool_api.create_pool,
+ name='foo',
+ slots='foo',
+ description='',
+ session=self.session)
+
+ def test_delete_pool(self):
+ pool = pool_api.delete_pool(name=self.pools[0].pool,
+ session=self.session)
+ self.assertEqual(pool.pool, self.pools[0].pool)
+ self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+ def test_delete_pool_non_existing(self):
+ self.assertRaisesRegexp(pool_api.PoolNotFound,
+ "^Pool 'test' doesn't exist$",
+ pool_api.delete_pool,
+ name='test',
+ session=self.session)
+
+ def test_delete_pool_bad_name(self):
+ for name in ('', ' '):
+ self.assertRaisesRegexp(pool_api.PoolBadRequest,
+ "^Pool name shouldn't be empty$",
+ pool_api.delete_pool,
+ name=name,
+ session=self.session)
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/api/common/mark_tasks.py
----------------------------------------------------------------------
diff --git a/tests/api/common/mark_tasks.py b/tests/api/common/mark_tasks.py
deleted file mode 100644
index 8a3759f..0000000
--- a/tests/api/common/mark_tasks.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import unittest
-
-from airflow import models
-from airflow.api.common.experimental.mark_tasks import (
- set_state, _create_dagruns, set_dag_run_state)
-from airflow.settings import Session
-from airflow.utils.dates import days_ago
-from airflow.utils.state import State
-from datetime import datetime, timedelta
-
-DEV_NULL = "/dev/null"
-
-
-class TestMarkTasks(unittest.TestCase):
- def setUp(self):
- self.dagbag = models.DagBag(include_examples=True)
- self.dag1 = self.dagbag.dags['test_example_bash_operator']
- self.dag2 = self.dagbag.dags['example_subdag_operator']
-
- self.execution_dates = [days_ago(2), days_ago(1)]
-
- drs = _create_dagruns(self.dag1, self.execution_dates,
- state=State.RUNNING,
- run_id_template="scheduled__{}")
- for dr in drs:
- dr.dag = self.dag1
- dr.verify_integrity()
-
- drs = _create_dagruns(self.dag2,
- [self.dag2.default_args['start_date']],
- state=State.RUNNING,
- run_id_template="scheduled__{}")
-
- for dr in drs:
- dr.dag = self.dag2
- dr.verify_integrity()
-
- self.session = Session()
-
- def snapshot_state(self, dag, execution_dates):
- TI = models.TaskInstance
- tis = self.session.query(TI).filter(
- TI.dag_id==dag.dag_id,
- TI.execution_date.in_(execution_dates)
- ).all()
-
- self.session.expunge_all()
-
- return tis
-
- def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
- TI = models.TaskInstance
-
- tis = self.session.query(TI).filter(
- TI.dag_id==dag.dag_id,
- TI.execution_date.in_(execution_dates)
- ).all()
-
- self.assertTrue(len(tis) > 0)
-
- for ti in tis:
- if ti.task_id in task_ids and ti.execution_date in execution_dates:
- self.assertEqual(ti.state, state)
- else:
- for old_ti in old_tis:
- if (old_ti.task_id == ti.task_id
- and old_ti.execution_date == ti.execution_date):
- self.assertEqual(ti.state, old_ti.state)
-
- def test_mark_tasks_now(self):
- # set one task to success but do not commit
- snapshot = self.snapshot_state(self.dag1, self.execution_dates)
- task = self.dag1.get_task("runme_1")
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=False)
- self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- None, snapshot)
-
- # set one and only one task to success
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
-
- # set no tasks
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 0)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
-
- # set task to other than success
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.FAILED, commit=True)
- self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.FAILED, snapshot)
-
- # dont alter other tasks
- snapshot = self.snapshot_state(self.dag1, self.execution_dates)
- task = self.dag1.get_task("runme_0")
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 1)
- self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
- State.SUCCESS, snapshot)
-
- def test_mark_downstream(self):
- # test downstream
- snapshot = self.snapshot_state(self.dag1, self.execution_dates)
- task = self.dag1.get_task("runme_1")
- relatives = task.get_flat_relatives(upstream=False)
- task_ids = [t.task_id for t in relatives]
- task_ids.append(task.task_id)
-
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=True, future=False,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 3)
- self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
- State.SUCCESS, snapshot)
-
- def test_mark_upstream(self):
- # test upstream
- snapshot = self.snapshot_state(self.dag1, self.execution_dates)
- task = self.dag1.get_task("run_after_loop")
- relatives = task.get_flat_relatives(upstream=True)
- task_ids = [t.task_id for t in relatives]
- task_ids.append(task.task_id)
-
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=True, downstream=False, future=False,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 4)
- self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
- State.SUCCESS, snapshot)
-
- def test_mark_tasks_future(self):
- # set one task to success towards end of scheduled dag runs
- snapshot = self.snapshot_state(self.dag1, self.execution_dates)
- task = self.dag1.get_task("runme_1")
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=False, future=True,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 2)
- self.verify_state(self.dag1, [task.task_id], self.execution_dates,
- State.SUCCESS, snapshot)
-
- def test_mark_tasks_past(self):
- # set one task to success towards end of scheduled dag runs
- snapshot = self.snapshot_state(self.dag1, self.execution_dates)
- task = self.dag1.get_task("runme_1")
- altered = set_state(task=task, execution_date=self.execution_dates[1],
- upstream=False, downstream=False, future=False,
- past=True, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 2)
- self.verify_state(self.dag1, [task.task_id], self.execution_dates,
- State.SUCCESS, snapshot)
-
- def test_mark_tasks_subdag(self):
- # set one task to success towards end of scheduled dag runs
- task = self.dag2.get_task("section-1")
- relatives = task.get_flat_relatives(upstream=False)
- task_ids = [t.task_id for t in relatives]
- task_ids.append(task.task_id)
-
- altered = set_state(task=task, execution_date=self.execution_dates[0],
- upstream=False, downstream=True, future=False,
- past=False, state=State.SUCCESS, commit=True)
- self.assertEqual(len(altered), 14)
-
- # cannot use snapshot here as that will require drilling down the
- # the sub dag tree essentially recreating the same code as in the
- # tested logic.
- self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
- State.SUCCESS, [])
-
- def tearDown(self):
- self.dag1.clear()
- self.dag2.clear()
-
- # just to make sure we are fully cleaned up
- self.session.query(models.DagRun).delete()
- self.session.query(models.TaskInstance).delete()
- self.session.commit()
-
- self.session.close()
-
-class TestMarkDAGRun(unittest.TestCase):
- def setUp(self):
- self.dagbag = models.DagBag(include_examples=True)
- self.dag1 = self.dagbag.dags['test_example_bash_operator']
- self.dag2 = self.dagbag.dags['example_subdag_operator']
-
- self.execution_dates = [days_ago(3), days_ago(2), days_ago(1)]
-
- self.session = Session()
-
- def verify_dag_run_states(self, dag, date, state=State.SUCCESS):
- drs = models.DagRun.find(dag_id=dag.dag_id, execution_date=date)
- dr = drs[0]
- self.assertEqual(dr.get_state(), state)
- tis = dr.get_task_instances(session=self.session)
- for ti in tis:
- self.assertEqual(ti.state, state)
-
- def test_set_running_dag_run_state(self):
- date = self.execution_dates[0]
- dr = self.dag1.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.RUNNING,
- execution_date=date,
- session=self.session
- )
- for ti in dr.get_task_instances(session=self.session):
- ti.set_state(State.RUNNING, self.session)
-
- altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
- # All of the task should be altered
- self.assertEqual(len(altered), len(self.dag1.tasks))
- self.verify_dag_run_states(self.dag1, date)
-
- def test_set_success_dag_run_state(self):
- date = self.execution_dates[0]
-
- dr = self.dag1.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.SUCCESS,
- execution_date=date,
- session=self.session
- )
- for ti in dr.get_task_instances(session=self.session):
- ti.set_state(State.SUCCESS, self.session)
-
- altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
- # None of the task should be altered
- self.assertEqual(len(altered), 0)
- self.verify_dag_run_states(self.dag1, date)
-
- def test_set_failed_dag_run_state(self):
- date = self.execution_dates[0]
- dr = self.dag1.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.FAILED,
- execution_date=date,
- session=self.session
- )
- dr.get_task_instance('runme_0').set_state(State.FAILED, self.session)
-
- altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
- # All of the task should be altered
- self.assertEqual(len(altered), len(self.dag1.tasks))
- self.verify_dag_run_states(self.dag1, date)
-
- def test_set_mixed_dag_run_state(self):
- """
- This test checks function set_dag_run_state with mixed task instance
- state.
- """
- date = self.execution_dates[0]
- dr = self.dag1.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.FAILED,
- execution_date=date,
- session=self.session
- )
- # success task
- dr.get_task_instance('runme_0').set_state(State.SUCCESS, self.session)
- # skipped task
- dr.get_task_instance('runme_1').set_state(State.SKIPPED, self.session)
- # retry task
- dr.get_task_instance('runme_2').set_state(State.UP_FOR_RETRY, self.session)
- # queued task
- dr.get_task_instance('also_run_this').set_state(State.QUEUED, self.session)
- # running task
- dr.get_task_instance('run_after_loop').set_state(State.RUNNING, self.session)
- # failed task
- dr.get_task_instance('run_this_last').set_state(State.FAILED, self.session)
-
- altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=True)
-
- self.assertEqual(len(altered), len(self.dag1.tasks) - 1) # only 1 task succeeded
- self.verify_dag_run_states(self.dag1, date)
-
- def test_set_state_without_commit(self):
- date = self.execution_dates[0]
-
- # Running dag run and task instances
- dr = self.dag1.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.RUNNING,
- execution_date=date,
- session=self.session
- )
- for ti in dr.get_task_instances(session=self.session):
- ti.set_state(State.RUNNING, self.session)
-
- altered = set_dag_run_state(self.dag1, date, state=State.SUCCESS, commit=False)
-
- # All of the task should be altered
- self.assertEqual(len(altered), len(self.dag1.tasks))
-
- # Both dag run and task instances' states should remain the same
- self.verify_dag_run_states(self.dag1, date, State.RUNNING)
-
- def test_set_state_with_multiple_dagruns(self):
- dr1 = self.dag2.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.FAILED,
- execution_date=self.execution_dates[0],
- session=self.session
- )
- dr2 = self.dag2.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.FAILED,
- execution_date=self.execution_dates[1],
- session=self.session
- )
- dr3 = self.dag2.create_dagrun(
- run_id='manual__' + datetime.now().isoformat(),
- state=State.RUNNING,
- execution_date=self.execution_dates[2],
- session=self.session
- )
-
- altered = set_dag_run_state(self.dag2, self.execution_dates[1],
- state=State.SUCCESS, commit=True)
-
- # Recursively count number of tasks in the dag
- def count_dag_tasks(dag):
- count = len(dag.tasks)
- subdag_counts = [count_dag_tasks(subdag) for subdag in dag.subdags]
- count += sum(subdag_counts)
- return count
-
- self.assertEqual(len(altered), count_dag_tasks(self.dag2))
- self.verify_dag_run_states(self.dag2, self.execution_dates[1])
-
- # Make sure other dag status are not changed
- dr1 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[0])
- dr1 = dr1[0]
- self.assertEqual(dr1.get_state(), State.FAILED)
- dr3 = models.DagRun.find(dag_id=self.dag2.dag_id, execution_date=self.execution_dates[2])
- dr3 = dr3[0]
- self.assertEqual(dr3.get_state(), State.RUNNING)
-
- def test_set_dag_run_state_edge_cases(self):
- # Dag does not exist
- altered = set_dag_run_state(None, self.execution_dates[0])
- self.assertEqual(len(altered), 0)
-
- # Invalid execution date
- altered = set_dag_run_state(self.dag1, None)
- self.assertEqual(len(altered), 0)
- self.assertRaises(AssertionError, set_dag_run_state, self.dag1, timedelta(microseconds=-1))
-
- # DagRun does not exist
- # This will throw AssertionError since dag.latest_execution_date does not exist
- self.assertRaises(AssertionError, set_dag_run_state, self.dag1, self.execution_dates[0])
-
- def tearDown(self):
- self.dag1.clear()
- self.dag2.clear()
-
- self.session.query(models.DagRun).delete()
- self.session.query(models.TaskInstance).delete()
- self.session.query(models.DagStat).delete()
- self.session.commit()
-
-if __name__ == '__main__':
- unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index 8ccd4e7..259b61d 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -1062,12 +1062,34 @@ class CoreTest(unittest.TestCase):
class CliTests(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(CliTests, cls).setUpClass()
+ cls._cleanup()
+
def setUp(self):
+ super(CliTests, self).setUp()
configuration.load_test_config()
app = application.create_app()
app.config['TESTING'] = True
self.parser = cli.CLIFactory.get_parser()
self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
+ self.session = Session()
+
+ def tearDown(self):
+ self._cleanup(session=self.session)
+ super(CliTests, self).tearDown()
+
+ @staticmethod
+ def _cleanup(session=None):
+ if session is None:
+ session = Session()
+
+ session.query(models.Pool).delete()
+ session.query(models.Variable).delete()
+ session.commit()
+ session.close()
def test_cli_list_dags(self):
args = self.parser.parse_args(['list_dags', '--report'])
@@ -1100,8 +1122,8 @@ class CliTests(unittest.TestCase):
cli.connections(self.parser.parse_args(['connections', '--list']))
stdout = mock_stdout.getvalue()
conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
- for ii, line in enumerate(stdout.split('\n'))
- if ii % 2 == 1]
+ for ii, line in enumerate(stdout.split('\n'))
+ if ii % 2 == 1]
conns = [conn for conn in conns if len(conn) > 0]
# Assert that some of the connections are present in the output as
@@ -1365,14 +1387,27 @@ class CliTests(unittest.TestCase):
'-c', 'NOT JSON'])
)
- def test_pool(self):
- # Checks if all subcommands are properly received
- cli.pool(self.parser.parse_args([
- 'pool', '-s', 'foo', '1', '"my foo pool"']))
- cli.pool(self.parser.parse_args([
- 'pool', '-g', 'foo']))
- cli.pool(self.parser.parse_args([
- 'pool', '-x', 'foo']))
+ def test_pool_create(self):
+ cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+ self.assertEqual(self.session.query(models.Pool).count(), 1)
+
+ def test_pool_get(self):
+ cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+ try:
+ cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
+ except Exception as e:
+ self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
+
+ def test_pool_delete(self):
+ cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
+ cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
+ self.assertEqual(self.session.query(models.Pool).count(), 0)
+
+ def test_pool_no_args(self):
+ try:
+ cli.pool(self.parser.parse_args(['pool']))
+ except Exception as e:
+ self.fail("The 'pool' command raised unexpectedly: %s" % e)
def test_variables(self):
# Checks if all subcommands are properly received
@@ -1426,10 +1461,6 @@ class CliTests(unittest.TestCase):
self.assertEqual('original', models.Variable.get('bar'))
self.assertEqual('{"foo": "bar"}', models.Variable.get('foo'))
- session = settings.Session()
- session.query(Variable).delete()
- session.commit()
- session.close()
os.remove('variables1.json')
os.remove('variables2.json')
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9958aa9d/tests/www/api/experimental/test_endpoints.py
----------------------------------------------------------------------
diff --git a/tests/www/api/experimental/test_endpoints.py b/tests/www/api/experimental/test_endpoints.py
index dacee32..65a6f75 100644
--- a/tests/www/api/experimental/test_endpoints.py
+++ b/tests/www/api/experimental/test_endpoints.py
@@ -19,22 +19,35 @@ from urllib.parse import quote_plus
from airflow import configuration
from airflow.api.common.experimental.trigger_dag import trigger_dag
-from airflow.models import DagBag, DagRun, TaskInstance
+from airflow.models import DagBag, DagRun, Pool, TaskInstance
from airflow.settings import Session
from airflow.www import app as application
-class ApiExperimentalTests(unittest.TestCase):
+class TestApiExperimental(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestApiExperimental, cls).setUpClass()
+ session = Session()
+ session.query(DagRun).delete()
+ session.query(TaskInstance).delete()
+ session.commit()
+ session.close()
def setUp(self):
+ super(TestApiExperimental, self).setUp()
configuration.load_test_config()
app = application.create_app(testing=True)
self.app = app.test_client()
+
+ def tearDown(self):
session = Session()
session.query(DagRun).delete()
session.query(TaskInstance).delete()
session.commit()
session.close()
+ super(TestApiExperimental, self).tearDown()
def test_task_info(self):
url_template = '/api/experimental/dags/{}/tasks/{}'
@@ -62,7 +75,7 @@ class ApiExperimentalTests(unittest.TestCase):
url_template = '/api/experimental/dags/{}/dag_runs'
response = self.app.post(
url_template.format('example_bash_operator'),
- data=json.dumps(dict(run_id='my_run' + datetime.now().isoformat())),
+ data=json.dumps({'run_id': 'my_run' + datetime.now().isoformat()}),
content_type="application/json"
)
@@ -70,7 +83,7 @@ class ApiExperimentalTests(unittest.TestCase):
response = self.app.post(
url_template.format('does_not_exist_dag'),
- data=json.dumps(dict()),
+ data=json.dumps({}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
@@ -88,7 +101,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test Correct execution
response = self.app.post(
url_template.format(dag_id),
- data=json.dumps(dict(execution_date=datetime_string)),
+ data=json.dumps({'execution_date': datetime_string}),
content_type="application/json"
)
self.assertEqual(200, response.status_code)
@@ -103,7 +116,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for nonexistent dag
response = self.app.post(
url_template.format('does_not_exist_dag'),
- data=json.dumps(dict(execution_date=execution_date.isoformat())),
+ data=json.dumps({'execution_date': execution_date.isoformat()}),
content_type="application/json"
)
self.assertEqual(404, response.status_code)
@@ -111,7 +124,7 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for bad datetime format
response = self.app.post(
url_template.format(dag_id),
- data=json.dumps(dict(execution_date='not_a_datetime')),
+ data=json.dumps({'execution_date': 'not_a_datetime'}),
content_type="application/json"
)
self.assertEqual(400, response.status_code)
@@ -122,7 +135,9 @@ class ApiExperimentalTests(unittest.TestCase):
task_id = 'also_run_this'
execution_date = datetime.now().replace(microsecond=0)
datetime_string = quote_plus(execution_date.isoformat())
- wrong_datetime_string = quote_plus(datetime(1990, 1, 1, 1, 1, 1).isoformat())
+ wrong_datetime_string = quote_plus(
+ datetime(1990, 1, 1, 1, 1, 1).isoformat()
+ )
# Create DagRun
trigger_dag(dag_id=dag_id,
@@ -139,7 +154,8 @@ class ApiExperimentalTests(unittest.TestCase):
# Test error for nonexistent dag
response = self.app.get(
- url_template.format('does_not_exist_dag', datetime_string, task_id),
+ url_template.format('does_not_exist_dag', datetime_string,
+ task_id),
)
self.assertEqual(404, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
@@ -164,3 +180,122 @@ class ApiExperimentalTests(unittest.TestCase):
)
self.assertEqual(400, response.status_code)
self.assertIn('error', response.data.decode('utf-8'))
+
+
+class TestPoolApiExperimental(unittest.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ super(TestPoolApiExperimental, cls).setUpClass()
+ session = Session()
+ session.query(Pool).delete()
+ session.commit()
+ session.close()
+
+ def setUp(self):
+ super(TestPoolApiExperimental, self).setUp()
+ configuration.load_test_config()
+ app = application.create_app(testing=True)
+ self.app = app.test_client()
+ self.session = Session()
+ self.pools = []
+ for i in range(2):
+ name = 'experimental_%s' % (i + 1)
+ pool = Pool(
+ pool=name,
+ slots=i,
+ description=name,
+ )
+ self.session.add(pool)
+ self.pools.append(pool)
+ self.session.commit()
+ self.pool = self.pools[0]
+
+ def tearDown(self):
+ self.session.query(Pool).delete()
+ self.session.commit()
+ self.session.close()
+ super(TestPoolApiExperimental, self).tearDown()
+
+ def _get_pool_count(self):
+ response = self.app.get('/api/experimental/pools')
+ self.assertEqual(response.status_code, 200)
+ return len(json.loads(response.data.decode('utf-8')))
+
+ def test_get_pool(self):
+ response = self.app.get(
+ '/api/experimental/pools/{}'.format(self.pool.pool),
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.data.decode('utf-8')),
+ self.pool.to_json())
+
+ def test_get_pool_non_existing(self):
+ response = self.app.get('/api/experimental/pools/foo')
+ self.assertEqual(response.status_code, 404)
+ self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+ "Pool 'foo' doesn't exist")
+
+ def test_get_pools(self):
+ response = self.app.get('/api/experimental/pools')
+ self.assertEqual(response.status_code, 200)
+ pools = json.loads(response.data.decode('utf-8'))
+ self.assertEqual(len(pools), 2)
+ for i, pool in enumerate(sorted(pools, key=lambda p: p['pool'])):
+ self.assertDictEqual(pool, self.pools[i].to_json())
+
+ def test_create_pool(self):
+ response = self.app.post(
+ '/api/experimental/pools',
+ data=json.dumps({
+ 'name': 'foo',
+ 'slots': 1,
+ 'description': '',
+ }),
+ content_type='application/json',
+ )
+ self.assertEqual(response.status_code, 200)
+ pool = json.loads(response.data.decode('utf-8'))
+ self.assertEqual(pool['pool'], 'foo')
+ self.assertEqual(pool['slots'], 1)
+ self.assertEqual(pool['description'], '')
+ self.assertEqual(self._get_pool_count(), 3)
+
+ def test_create_pool_with_bad_name(self):
+ for name in ('', ' '):
+ response = self.app.post(
+ '/api/experimental/pools',
+ data=json.dumps({
+ 'name': name,
+ 'slots': 1,
+ 'description': '',
+ }),
+ content_type='application/json',
+ )
+ self.assertEqual(response.status_code, 400)
+ self.assertEqual(
+ json.loads(response.data.decode('utf-8'))['error'],
+ "Pool name shouldn't be empty",
+ )
+ self.assertEqual(self._get_pool_count(), 2)
+
+ def test_delete_pool(self):
+ response = self.app.delete(
+ '/api/experimental/pools/{}'.format(self.pool.pool),
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(json.loads(response.data.decode('utf-8')),
+ self.pool.to_json())
+ self.assertEqual(self._get_pool_count(), 1)
+
+ def test_delete_pool_non_existing(self):
+ response = self.app.delete(
+ '/api/experimental/pools/foo',
+ )
+ self.assertEqual(response.status_code, 404)
+ self.assertEqual(json.loads(response.data.decode('utf-8'))['error'],
+ "Pool 'foo' doesn't exist")
+
+
+if __name__ == '__main__':
+ unittest.main()