Posted to commits@airflow.apache.org by ka...@apache.org on 2018/05/18 01:08:00 UTC
incubator-airflow git commit: [AIRFLOW-2458] Add cassandra-to-gcs operator
Repository: incubator-airflow
Updated Branches:
refs/heads/master 8873a8df8 -> f5115b7e6
[AIRFLOW-2458] Add cassandra-to-gcs operator
Closes #3354 from jgao54/cassandra-to-gcs
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/f5115b7e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/f5115b7e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/f5115b7e
Branch: refs/heads/master
Commit: f5115b7e6a105e6baedd8efa9b4d4afc12ee880d
Parents: 8873a8d
Author: Joy Gao <Jo...@apache.org>
Authored: Fri May 18 02:01:41 2018 +0100
Committer: Kaxil Naik <ka...@apache.org>
Committed: Fri May 18 02:02:57 2018 +0100
----------------------------------------------------------------------
airflow/contrib/hooks/cassandra_hook.py | 88 +++++
airflow/contrib/operators/cassandra_to_gcs.py | 351 +++++++++++++++++++
airflow/models.py | 4 +
airflow/utils/db.py | 4 +
docs/code.rst | 2 +
setup.py | 7 +-
tests/contrib/hooks/test_cassandra_hook.py | 56 +++
.../operators/test_cassandra_to_gcs_operator.py | 92 +++++
8 files changed, 602 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/contrib/hooks/cassandra_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py
new file mode 100644
index 0000000..90046a8
--- /dev/null
+++ b/airflow/contrib/hooks/cassandra_hook.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from cassandra.cluster import Cluster
+from cassandra.policies import (RoundRobinPolicy, DCAwareRoundRobinPolicy,
+ TokenAwarePolicy, HostFilterPolicy,
+ WhiteListRoundRobinPolicy)
+from cassandra.auth import PlainTextAuthProvider
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class CassandraHook(BaseHook, LoggingMixin):
+ """
+ Hook used to interact with Cassandra.
+
+ Contact points can be specified as a comma-separated string in the 'host'
+ field of the connection. Port can be specified in the port field of the
+ connection. load_balancing_policy, ssl_options and cql_version can be
+ specified as JSON in the extra field of the connection.
+
+ For details of the Cluster config, see cassandra.cluster.
+ """
+ def __init__(self, cassandra_conn_id='cassandra_default'):
+ conn = self.get_connection(cassandra_conn_id)
+
+ conn_config = {}
+ if conn.host:
+ conn_config['contact_points'] = conn.host.split(',')
+
+ if conn.port:
+ conn_config['port'] = int(conn.port)
+
+ if conn.login:
+ conn_config['auth_provider'] = PlainTextAuthProvider(
+ username=conn.login, password=conn.password)
+
+ lb_policy = self.get_policy(conn.extra_dejson.get('load_balancing_policy', None))
+ if lb_policy:
+ conn_config['load_balancing_policy'] = lb_policy
+
+ cql_version = conn.extra_dejson.get('cql_version', None)
+ if cql_version:
+ conn_config['cql_version'] = cql_version
+
+ ssl_options = conn.extra_dejson.get('ssl_options', None)
+ if ssl_options:
+ conn_config['ssl_options'] = ssl_options
+
+ self.cluster = Cluster(**conn_config)
+ self.keyspace = conn.schema
+
+ def get_conn(self):
+ """
+ Returns a Cassandra session object
+ """
+ return self.cluster.connect(self.keyspace)
+
+ def get_cluster(self):
+ return self.cluster
+
+ @classmethod
+ def get_policy(cls, policy_name):
+ policies = {
+ 'RoundRobinPolicy': RoundRobinPolicy,
+ 'DCAwareRoundRobinPolicy': DCAwareRoundRobinPolicy,
+ 'TokenAwarePolicy': TokenAwarePolicy,
+ 'HostFilterPolicy': HostFilterPolicy,
+ 'WhiteListRoundRobinPolicy': WhiteListRoundRobinPolicy,
+ }
+ return policies.get(policy_name)
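
A minimal usage sketch for the new hook (the connection id matches the default
seeded in airflow/utils/db.py below; the keyspace and table names are
illustrative and not part of this commit):

    from airflow.contrib.hooks.cassandra_hook import CassandraHook

    # Assumes an Airflow connection 'cassandra_default' whose host field holds
    # a comma-separated list of contact points and whose extra field may carry
    # JSON settings such as
    # {"load_balancing_policy": "TokenAwarePolicy", "cql_version": "3.4.4"}.
    hook = CassandraHook(cassandra_conn_id='cassandra_default')
    session = hook.get_conn()  # cassandra.cluster.Session bound to conn.schema
    for row in session.execute('SELECT * FROM my_keyspace.my_table LIMIT 10'):
        print(row)
    hook.get_cluster().shutdown()  # release cluster resources when done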
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/contrib/operators/cassandra_to_gcs.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py
new file mode 100644
index 0000000..b4e216d
--- /dev/null
+++ b/airflow/contrib/operators/cassandra_to_gcs.py
@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import unicode_literals
+
+import json
+from builtins import str
+from base64 import b64encode
+from cassandra.util import Date, Time, SortedSet, OrderedMapSerializedKey
+from datetime import datetime
+from decimal import Decimal
+from six import text_type, binary_type, PY3
+from tempfile import NamedTemporaryFile
+from uuid import UUID
+
+from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
+from airflow.contrib.hooks.cassandra_hook import CassandraHook
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class CassandraToGoogleCloudStorageOperator(BaseOperator):
+ """
+ Copy data from Cassandra to Google Cloud Storage in JSON format.
+
+ Note: Arrays of arrays are not supported.
+ """
+ template_fields = ('cql', 'bucket', 'filename', 'schema_filename',)
+ template_ext = ('.cql',)
+ ui_color = '#a0e08c'
+
+ @apply_defaults
+ def __init__(self,
+ cql,
+ bucket,
+ filename,
+ schema_filename=None,
+ approx_max_file_size_bytes=1900000000,
+ cassandra_conn_id='cassandra_default',
+ google_cloud_storage_conn_id='google_cloud_default',
+ delegate_to=None,
+ *args,
+ **kwargs):
+ """
+ :param cql: The CQL to execute on the Cassandra table.
+ :type cql: string
+ :param bucket: The bucket to upload to.
+ :type bucket: string
+ :param filename: The filename to use as the object name when uploading
+ to Google Cloud Storage. A {} should be specified in the filename
+ to allow the operator to inject file numbers in cases where the
+ file is split due to size.
+ :type filename: string
+ :param schema_filename: If set, the filename to use as the object name
+ when uploading a .json file containing the BigQuery schema fields
+ for the table that was dumped from Cassandra.
+ :type schema_filename: string
+ :param approx_max_file_size_bytes: This operator supports the ability
+ to split large table dumps into multiple files (see notes in the
+ filename param docs above). Google Cloud Storage allows for files
+ to be a maximum of 4GB. This param allows developers to specify the
+ file size of the splits.
+ :type approx_max_file_size_bytes: long
+ :param cassandra_conn_id: Reference to a specific Cassandra hook.
+ :type cassandra_conn_id: string
+ :param google_cloud_storage_conn_id: Reference to a specific Google
+ Cloud Storage hook.
+ :type google_cloud_storage_conn_id: string
+ :param delegate_to: The account to impersonate, if any. For this to
+ work, the service account making the request must have domain-wide
+ delegation enabled.
+ :type delegate_to: string
+ """
+ super(CassandraToGoogleCloudStorageOperator, self).__init__(*args, **kwargs)
+ self.cql = cql
+ self.bucket = bucket
+ self.filename = filename
+ self.schema_filename = schema_filename
+ self.approx_max_file_size_bytes = approx_max_file_size_bytes
+ self.cassandra_conn_id = cassandra_conn_id
+ self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
+ self.delegate_to = delegate_to
+
+ # Default Cassandra to BigQuery type mapping
+ CQL_TYPE_MAP = {
+ 'BytesType': 'BYTES',
+ 'DecimalType': 'FLOAT',
+ 'UUIDType': 'STRING',
+ 'BooleanType': 'BOOL',
+ 'ByteType': 'INTEGER',
+ 'AsciiType': 'STRING',
+ 'FloatType': 'FLOAT',
+ 'DoubleType': 'FLOAT',
+ 'LongType': 'INTEGER',
+ 'Int32Type': 'INTEGER',
+ 'IntegerType': 'INTEGER',
+ 'InetAddressType': 'STRING',
+ 'CounterColumnType': 'INTEGER',
+ 'DateType': 'TIMESTAMP',
+ 'SimpleDateType': 'DATE',
+ 'TimestampType': 'TIMESTAMP',
+ 'TimeUUIDType': 'BYTES',
+ 'ShortType': 'INTEGER',
+ 'TimeType': 'TIME',
+ 'DurationType': 'INTEGER',
+ 'UTF8Type': 'STRING',
+ 'VarcharType': 'STRING',
+ }
+
+ def execute(self, context):
+ cursor = self._query_cassandra()
+ files_to_upload = self._write_local_data_files(cursor)
+
+ # If a schema is set, create a BQ schema JSON file.
+ if self.schema_filename:
+ files_to_upload.update(self._write_local_schema_file(cursor))
+
+ # Flush all files before uploading
+ for file_handle in files_to_upload.values():
+ file_handle.flush()
+
+ self._upload_to_gcs(files_to_upload)
+
+ # Close all temp file handles.
+ for file_handle in files_to_upload.values():
+ file_handle.close()
+
+ def _query_cassandra(self):
+ """
+ Queries Cassandra and returns a cursor to the results.
+ """
+ hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id)
+ session = hook.get_conn()
+ cursor = session.execute(self.cql)
+ return cursor
+
+ def _write_local_data_files(self, cursor):
+ """
+ Takes a cursor, and writes results to a local file.
+
+ :return: A dictionary where keys are filenames to be used as object
+ names in GCS, and values are file handles to local files that
+ contain the data for the GCS objects.
+ """
+ file_no = 0
+ tmp_file_handle = NamedTemporaryFile(delete=True)
+ tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
+ for row in cursor:
+ row_dict = self.generate_data_dict(row._fields, row)
+ s = json.dumps(row_dict)
+ if PY3:
+ s = s.encode('utf-8')
+ tmp_file_handle.write(s)
+
+ # Append newline to make dumps BigQuery compatible.
+ tmp_file_handle.write(b'\n')
+
+ if tmp_file_handle.tell() >= self.approx_max_file_size_bytes:
+ file_no += 1
+ tmp_file_handle = NamedTemporaryFile(delete=True)
+ tmp_file_handles[self.filename.format(file_no)] = tmp_file_handle
+
+ return tmp_file_handles
+
+ def _write_local_schema_file(self, cursor):
+ """
+ Takes a cursor, and writes the BigQuery schema for the results to a
+ local file system.
+
+ :return: A dictionary whose key is the filename to be used as an object
+ name in GCS, and whose value is a file handle to a local file that
+ contains the BigQuery schema fields in .json format.
+ """
+ schema = []
+ tmp_schema_file_handle = NamedTemporaryFile(delete=True)
+
+ for name, type in zip(cursor.column_names, cursor.column_types):
+ schema.append(self.generate_schema_dict(name, type))
+ json_serialized_schema = json.dumps(schema)
+ if PY3:
+ json_serialized_schema = json_serialized_schema.encode('utf-8')
+
+ tmp_schema_file_handle.write(json_serialized_schema)
+ return {self.schema_filename: tmp_schema_file_handle}
+
+ def _upload_to_gcs(self, files_to_upload):
+ hook = GoogleCloudStorageHook(
+ google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
+ delegate_to=self.delegate_to)
+ for object, tmp_file_handle in files_to_upload.items():
+ hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json')
+
+ @classmethod
+ def generate_data_dict(cls, names, values):
+ row_dict = {}
+ for name, value in zip(names, values):
+ row_dict.update({name: cls.convert_value(name, value)})
+ return row_dict
+
+ @classmethod
+ def convert_value(cls, name, value):
+ if not value:
+ return value
+ elif isinstance(value, (text_type, int, float, bool, dict)):
+ return value
+ elif isinstance(value, binary_type):
+ encoded_value = b64encode(value)
+ if PY3:
+ encoded_value = encoded_value.decode('ascii')
+ return encoded_value
+ elif isinstance(value, (datetime, Date, UUID)):
+ return str(value)
+ elif isinstance(value, Decimal):
+ return float(value)
+ elif isinstance(value, Time):
+ return str(value).split('.')[0]
+ elif isinstance(value, (list, SortedSet)):
+ return cls.convert_array_types(name, value)
+ elif hasattr(value, '_fields'):
+ return cls.convert_user_type(name, value)
+ elif isinstance(value, tuple):
+ return cls.convert_tuple_type(name, value)
+ elif isinstance(value, OrderedMapSerializedKey):
+ return cls.convert_map_type(name, value)
+ else:
+ raise AirflowException('unexpected value: ' + str(value))
+
+ @classmethod
+ def convert_array_types(cls, name, value):
+ return [cls.convert_value(name, nested_value) for nested_value in value]
+
+ @classmethod
+ def convert_user_type(cls, name, value):
+ """
+ Converts a user type to RECORD that contains n fields, where n is the
+ number of attributes. Each element in the user type class will be converted to its
+ corresponding data type in BQ.
+ """
+ names = value._fields
+ values = [cls.convert_value(name, getattr(value, name)) for name in names]
+ return cls.generate_data_dict(names, values)
+
+ @classmethod
+ def convert_tuple_type(cls, name, value):
+ """
+ Converts a tuple to a RECORD with n fields, each converted to its
+ corresponding data type in BQ and named 'field_<index>', where index is
+ determined by the order of the tuple elements defined in Cassandra.
+ """
+ names = ['field_' + str(i) for i in range(len(value))]
+ values = [cls.convert_value(name, value) for name, value in zip(names, value)]
+ return cls.generate_data_dict(names, values)
+
+ @classmethod
+ def convert_map_type(cls, name, value):
+ """
+ Converts a map to a repeated RECORD with two fields, 'key' and 'value',
+ each converted to its corresponding data type in BQ.
+ """
+ converted_map = []
+ for k, v in zip(value.keys(), value.values()):
+ converted_map.append({
+ 'key': cls.convert_value('key', k),
+ 'value': cls.convert_value('value', v)
+ })
+ return converted_map
+
+ @classmethod
+ def generate_schema_dict(cls, name, type):
+ field_schema = dict()
+ field_schema.update({'name': name})
+ field_schema.update({'type': cls.get_bq_type(type)})
+ field_schema.update({'mode': cls.get_bq_mode(type)})
+ fields = cls.get_bq_fields(name, type)
+ if fields:
+ field_schema.update({'fields': fields})
+ return field_schema
+
+ @classmethod
+ def get_bq_fields(cls, name, type):
+ fields = []
+
+ if not cls.is_simple_type(type):
+ names, types = [], []
+
+ if cls.is_array_type(type) and cls.is_record_type(type.subtypes[0]):
+ names = type.subtypes[0].fieldnames
+ types = type.subtypes[0].subtypes
+ elif cls.is_record_type(type):
+ names = type.fieldnames
+ types = type.subtypes
+
+ if types and not names and type.cassname == 'TupleType':
+ names = ['field_' + str(i) for i in range(len(types))]
+ elif types and not names and type.cassname == 'MapType':
+ names = ['key', 'value']
+
+ for name, type in zip(names, types):
+ field = cls.generate_schema_dict(name, type)
+ fields.append(field)
+
+ return fields
+
+ @classmethod
+ def is_simple_type(cls, type):
+ return type.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP
+
+ @classmethod
+ def is_array_type(cls, type):
+ return type.cassname in ['ListType', 'SetType']
+
+ @classmethod
+ def is_record_type(cls, type):
+ return type.cassname in ['UserType', 'TupleType', 'MapType']
+
+ @classmethod
+ def get_bq_type(cls, type):
+ if cls.is_simple_type(type):
+ return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type.cassname]
+ elif cls.is_record_type(type):
+ return 'RECORD'
+ elif cls.is_array_type(type):
+ return cls.get_bq_type(type.subtypes[0])
+ else:
+ raise AirflowException('Not a supported type: ' + type.cassname)
+
+ @classmethod
+ def get_bq_mode(cls, type):
+ if cls.is_array_type(type) or type.cassname == 'MapType':
+ return 'REPEATED'
+ elif cls.is_record_type(type) or cls.is_simple_type(type):
+ return 'NULLABLE'
+ else:
+ raise AirflowException('Not a supported type: ' + type.cassname)
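
A minimal DAG sketch showing how the operator might be wired up. The dag id,
table, bucket, and object names are hypothetical; the '{}' in filename is where
the operator injects a file number when the dump is split by
approx_max_file_size_bytes:

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.cassandra_to_gcs import \
        CassandraToGoogleCloudStorageOperator

    with DAG('cassandra_to_gcs_example',
             start_date=datetime(2018, 5, 1),
             schedule_interval=None) as dag:
        export = CassandraToGoogleCloudStorageOperator(
            task_id='export_events',
            cql='SELECT * FROM my_keyspace.events',
            bucket='my-gcs-bucket',
            filename='exports/events/data_{}.json',  # '{}' -> 0, 1, 2, ...
            schema_filename='exports/events/schema.json',
            cassandra_conn_id='cassandra_default',
            google_cloud_storage_conn_id='google_cloud_default')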
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index c9fee0c..7aab4b5 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -603,6 +603,7 @@ class Connection(Base, LoggingMixin):
('snowflake', 'Snowflake',),
('segment', 'Segment',),
('azure_data_lake', 'Azure Data Lake'),
+ ('cassandra', 'Cassandra',),
]
def __init__(
@@ -753,6 +754,9 @@ class Connection(Base, LoggingMixin):
elif self.conn_type == 'azure_data_lake':
from airflow.contrib.hooks.azure_data_lake_hook import AzureDataLakeHook
return AzureDataLakeHook(azure_data_lake_conn_id=self.conn_id)
+ elif self.conn_type == 'cassandra':
+ from airflow.contrib.hooks.cassandra_hook import CassandraHook
+ return CassandraHook(cassandra_conn_id=self.conn_id)
except:
pass
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index adda6fd..270939a 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -276,6 +276,10 @@ def initdb(rbac=False):
models.Connection(
conn_id='azure_data_lake_default', conn_type='azure_data_lake',
extra='{"tenant": "<TENANT>", "account_name": "<ACCOUNTNAME>" }'))
+ merge_conn(
+ models.Connection(
+ conn_id='cassandra_default', conn_type='cassandra',
+ host='localhost', port=9042))
# Known event types
KET = models.KnownEventType
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/docs/code.rst
----------------------------------------------------------------------
diff --git a/docs/code.rst b/docs/code.rst
index 857bf67..1737d15 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -121,6 +121,7 @@ Operators
.. autoclass:: airflow.contrib.operators.bigquery_table_delete_operator.BigQueryTableDeleteOperator
.. autoclass:: airflow.contrib.operators.bigquery_to_bigquery.BigQueryToBigQueryOperator
.. autoclass:: airflow.contrib.operators.bigquery_to_gcs.BigQueryToCloudStorageOperator
+.. autoclass:: airflow.contrib.operators.cassandra_to_gcs.CassandraToGoogleCloudStorageOperator
.. autoclass:: airflow.contrib.operators.databricks_operator.DatabricksSubmitRunOperator
.. autoclass:: airflow.contrib.operators.dataflow_operator.DataFlowJavaOperator
.. autoclass:: airflow.contrib.operators.dataflow_operator.DataflowTemplateOperator
@@ -354,6 +355,7 @@ Community contributed hooks
.. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
.. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook
.. autoclass:: airflow.contrib.hooks.bigquery_hook.BigQueryHook
+.. autoclass:: airflow.contrib.hooks.cassandra_hook.CassandraHook
.. autoclass:: airflow.contrib.hooks.cloudant_hook.CloudantHook
.. autoclass:: airflow.contrib.hooks.databricks_hook.DatabricksHook
.. autoclass:: airflow.contrib.hooks.datadog_hook.DatadogHook
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/setup.py
----------------------------------------------------------------------
diff --git a/setup.py b/setup.py
index 9813ea2..97b6883 100644
--- a/setup.py
+++ b/setup.py
@@ -114,7 +114,7 @@ azure_data_lake = [
'azure-mgmt-datalake-store==0.4.0',
'azure-datalake-store==0.0.19'
]
-sendgrid = ['sendgrid>=5.2.0']
+cassandra = ['cassandra-driver>=3.13.0']
celery = [
'celery>=4.0.2',
'flower>=0.7.3'
@@ -184,6 +184,7 @@ s3 = ['boto3>=1.7.0']
salesforce = ['simple-salesforce>=0.72']
samba = ['pysmbclient>=0.1.3']
segment = ['analytics-python>=1.2.9']
+sendgrid = ['sendgrid>=5.2.0']
slack = ['slackclient>=1.0.0']
snowflake = ['snowflake-connector-python>=1.5.2',
'snowflake-sqlalchemy>=1.1.0']
@@ -194,7 +195,8 @@ webhdfs = ['hdfs[dataframe,avro,kerberos]>=2.0.4']
winrm = ['pywinrm==0.2.2']
zendesk = ['zdesk']
-all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid + pinot
+all_dbs = postgres + mysql + hive + mssql + hdfs + vertica + cloudant + druid + pinot \
+ + cassandra
devel = [
'click',
'freezegun',
@@ -290,6 +292,7 @@ def do_setup():
'async': async,
'azure_blob_storage': azure_blob_storage,
'azure_data_lake': azure_data_lake,
+ 'cassandra': cassandra,
'celery': celery,
'cgroups': cgroups,
'cloudant': cloudant,
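
With the 'cassandra' extra registered above, the driver dependency can be
pulled in alongside Airflow (assuming the apache-airflow distribution name in
use at the time):

    pip install apache-airflow[cassandra]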
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/tests/contrib/hooks/test_cassandra_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_cassandra_hook.py b/tests/contrib/hooks/test_cassandra_hook.py
new file mode 100644
index 0000000..42afd9e
--- /dev/null
+++ b/tests/contrib/hooks/test_cassandra_hook.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import mock
+
+from airflow import configuration
+from airflow.contrib.hooks.cassandra_hook import CassandraHook
+from cassandra.cluster import Cluster
+from cassandra.policies import TokenAwarePolicy
+from airflow import models
+from airflow.utils import db
+
+
+class CassandraHookTest(unittest.TestCase):
+ def setUp(self):
+ configuration.load_test_config()
+ db.merge_conn(
+ models.Connection(
+ conn_id='cassandra_test', conn_type='cassandra',
+ host='host-1,host-2', port='9042', schema='test_keyspace',
+ extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
+
+ def test_get_conn(self):
+ with mock.patch.object(Cluster, "connect") as mock_connect, \
+ mock.patch("socket.getaddrinfo", return_value=[]) as mock_getaddrinfo:
+ mock_connect.return_value = 'session'
+ hook = CassandraHook(cassandra_conn_id='cassandra_test')
+ hook.get_conn()
+ mock_getaddrinfo.assert_called()
+ mock_connect.assert_called_once_with('test_keyspace')
+
+ cluster = hook.get_cluster()
+ self.assertEqual(cluster.contact_points, ['host-1', 'host-2'])
+ self.assertEqual(cluster.port, 9042)
+ self.assertTrue(isinstance(cluster.load_balancing_policy, TokenAwarePolicy))
+
+
+if __name__ == '__main__':
+ unittest.main()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/f5115b7e/tests/contrib/operators/test_cassandra_to_gcs_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_cassandra_to_gcs_operator.py b/tests/contrib/operators/test_cassandra_to_gcs_operator.py
new file mode 100644
index 0000000..add115f
--- /dev/null
+++ b/tests/contrib/operators/test_cassandra_to_gcs_operator.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import unicode_literals
+
+import unittest
+import mock
+from builtins import str
+from airflow.contrib.operators.cassandra_to_gcs import \
+ CassandraToGoogleCloudStorageOperator
+
+
+class CassandraToGCSTest(unittest.TestCase):
+
+ @mock.patch('airflow.contrib.operators.cassandra_to_gcs.GoogleCloudStorageHook.upload')
+ @mock.patch('airflow.contrib.hooks.cassandra_hook.CassandraHook.get_conn')
+ def test_execute(self, get_conn, upload):
+ operator = CassandraToGoogleCloudStorageOperator(
+ task_id='test-cas-to-gcs',
+ cql='select * from keyspace1.table1',
+ bucket='test-bucket',
+ filename='data.json',
+ schema_filename='schema.json')
+
+ operator.execute(None)
+
+ self.assertTrue(get_conn.called)
+ self.assertTrue(upload.called)
+
+ def test_convert_value(self):
+ op = CassandraToGoogleCloudStorageOperator
+ self.assertEquals(op.convert_value('None', None), None)
+ self.assertEquals(op.convert_value('int', 1), 1)
+ self.assertEquals(op.convert_value('float', 1.0), 1.0)
+ self.assertEquals(op.convert_value('str', "text"), "text")
+ self.assertEquals(op.convert_value('bool', True), True)
+ self.assertEquals(op.convert_value('dict', {"a": "b"}), {"a": "b"})
+
+ from datetime import datetime
+ now = datetime.now()
+ self.assertEquals(op.convert_value('datetime', now), str(now))
+
+ from cassandra.util import Date
+ date_str = '2018-01-01'
+ date = Date(date_str)
+ self.assertEquals(op.convert_value('date', date), str(date_str))
+
+ import uuid
+ test_uuid = uuid.uuid4()
+ self.assertEquals(op.convert_value('uuid', test_uuid), str(test_uuid))
+
+ from decimal import Decimal
+ d = Decimal(1.0)
+ self.assertEquals(op.convert_value('decimal', d), float(d))
+
+ from base64 import b64encode
+ b = b'abc'
+ encoded_b = b64encode(b).decode('ascii')
+ self.assertEquals(op.convert_value('binary', b), encoded_b)
+
+ from cassandra.util import Time
+ time = Time(0)
+ self.assertEquals(op.convert_value('time', time), '00:00:00')
+
+ date_str_lst = ['2018-01-01', '2018-01-02', '2018-01-03']
+ date_lst = [Date(d) for d in date_str_lst]
+ self.assertEquals(op.convert_value('list', date_lst), date_str_lst)
+
+ date_tpl = tuple(date_lst)
+ self.assertEquals(op.convert_value('tuple', date_tpl),
+ {'field_0': '2018-01-01',
+ 'field_1': '2018-01-02',
+ 'field_2': '2018-01-03', })
+
+
+if __name__ == '__main__':
+ unittest.main()