diff --git a/README.md b/README.md
index 70b786e..8109375 100644
--- a/README.md
+++ b/README.md
@@ -203,6 +203,7 @@ for more details. *But* it exposes one additional method:
 * `split_size` sets the size in the number of CQL rows in each partition (defaults to `100000`)
 * `fetch_size` sets the number of rows to fetch per request from Cassandra (defaults to `1000`)
 * `consistency_level` sets with which consistency level to read the data (defaults to `LOCAL_ONE`)
+* `connection_config` sets a map with connection information for connecting to a non-default cluster (defaults to `None`)
 
 ### pyspark.RDD
 
@@ -311,6 +312,15 @@ sc \
     .collect()
 ```
 
+Reading from different clusters::
+```python
+rdd_one = sc \
+    .cassandraTable("keyspace", "table_one", connection_config={"spark_cassandra_connection_host": "cas-1"})
+
+rdd_two = sc \
+    .cassandraTable("keyspace", "table_two", connection_config={"spark_cassandra_connection_host": "cas-2"})
+```
+
 Storing data in Cassandra::
 
 ```python
@@ -330,6 +340,14 @@ rdd.saveToCassandra(
     "table",
     ttl=timedelta(hours=1),
 )
+
+# Storing to a non-default cluster
+rdd.saveToCassandra(
+    "keyspace",
+    "table",
+    ttl=timedelta(hours=1),
+    connection_config={"spark_cassandra_connection_host": "cas-2"}
+)
 ```
 
 Modify CQL collections::
diff --git a/python/pyspark_cassandra/conf.py b/python/pyspark_cassandra/conf.py
index 0eb8421..ff003a1 100644
--- a/python/pyspark_cassandra/conf.py
+++ b/python/pyspark_cassandra/conf.py
@@ -34,6 +34,9 @@ def __str__(self):
             ', '.join('%s=%s' % (k, v) for k, v in self.settings().items())
         )
 
+class ConnectionConf(_Conf):
+    def __init__(self, spark_cassandra_connection_host):
+        self.spark_cassandra_connection_host = spark_cassandra_connection_host
 
 class ReadConf(_Conf):
     def __init__(self, split_count=None, split_size=None, fetch_size=None,
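For reference, the `connection_config` keyword accepted by the Python API is turned into this `ConnectionConf` through the same `_Conf.build(...)`/`settings()` machinery already used for `ReadConf` and `WriteConf`, and the resulting map is what gets handed to the JVM helper in `rdd.py` below. A minimal sketch of that round trip (the exact shape returned by `settings()` is an assumption, not taken from the diff):

```python
# Illustrative sketch only -- not part of the patch.
from pyspark_cassandra.conf import ConnectionConf

conn_conf = ConnectionConf.build(spark_cassandra_connection_host="cas-2")

# settings() is assumed to return the non-empty attributes as a plain dict;
# that dict is what as_java_object(...) converts before calling the helper.
print(conn_conf.settings())  # e.g. {'spark_cassandra_connection_host': 'cas-2'}
```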
diff --git a/python/pyspark_cassandra/rdd.py b/python/pyspark_cassandra/rdd.py
index db8b7a9..b7fab56 100644
--- a/python/pyspark_cassandra/rdd.py
+++ b/python/pyspark_cassandra/rdd.py
@@ -11,13 +11,14 @@
 #   limitations under the License.
 
 import sys
+from functools import partial
 from copy import copy
 from itertools import groupby
 from operator import itemgetter
 
 from pyspark.rdd import RDD
 
-from .conf import ReadConf, WriteConf
+from .conf import ReadConf, WriteConf, ConnectionConf
 from .format import ColumnSelector, RowFormat
 from .types import Row
 from .util import as_java_array, as_java_object, helper
@@ -33,7 +34,7 @@
 
 def saveToCassandra(rdd, keyspace=None, table=None, columns=None,
                     row_format=None, keyed=None,
-                    write_conf=None, **write_conf_kwargs):
+                    write_conf=None, connection_config=None, **write_conf_kwargs):
     """
         Saves an RDD to Cassandra. The RDD is expected to contain dicts with
         keys mapping to CQL columns.
@@ -63,6 +64,8 @@ def saveToCassandra(rdd, keyspace=None, table=None, columns=None,
         @param write_conf(WriteConf):
             A WriteConf object to use when saving to Cassandra
+        @param connection_config(ConnectionConf):
+            A ConnectionConf object to use when saving to a non-default Cassandra cluster
         @param **write_conf_kwargs:
             WriteConf parameters to use when saving to Cassandra
     """
@@ -88,21 +91,22 @@ def saveToCassandra(rdd, keyspace=None, table=None, columns=None,
     columns = as_java_array(rdd.ctx._gateway, "String", columns) \
         if columns else None
 
-    helper(rdd.ctx) \
-        .saveToCassandra(
-            rdd._jrdd,
-            keyspace,
-            table,
-            columns,
-            row_format,
-            keyed,
-            write_conf,
-        )
+    save = partial(helper(rdd.ctx).saveToCassandra, rdd._jrdd,
+                   keyspace,
+                   table,
+                   columns,
+                   row_format,
+                   keyed,
+                   write_conf)
+    if connection_config:
+        conn_conf = as_java_object(rdd.ctx._gateway, ConnectionConf.build(**connection_config).settings())
+        save = partial(save, conn_conf)
+    save()
 
 
 def deleteFromCassandra(rdd, keyspace=None, table=None, deleteColumns=None,
                         keyColumns=None, row_format=None, keyed=None,
-                        write_conf=None, **write_conf_kwargs):
+                        write_conf=None, connection_config=None, **write_conf_kwargs):
     """
         Delete data from Cassandra table, using data from the RDD as primary
         keys. Uses the specified column names.
@@ -136,6 +140,8 @@ def deleteFromCassandra(rdd, keyspace=None, table=None, deleteColumns=None,
         @param write_conf(WriteConf):
             A WriteConf object to use when saving to Cassandra
+        @param connection_config(ConnectionConf):
+            A ConnectionConf object to use when deleting from a non-default Cassandra cluster
         @param **write_conf_kwargs:
             WriteConf parameters to use when saving to Cassandra
     """
@@ -157,18 +163,18 @@ def deleteFromCassandra(rdd, keyspace=None, table=None, deleteColumns=None,
         if deleteColumns else None
     keyColumns = as_java_array(rdd.ctx._gateway, "String", keyColumns) \
         if keyColumns else None
-
-    helper(rdd.ctx) \
-        .deleteFromCassandra(
-            rdd._jrdd,
-            keyspace,
-            table,
-            deleteColumns,
-            keyColumns,
-            row_format,
-            keyed,
-            write_conf,
-        )
+    fn_delete = partial(helper(rdd.ctx).deleteFromCassandra, rdd._jrdd,
+                        keyspace,
+                        table,
+                        deleteColumns,
+                        keyColumns,
+                        row_format,
+                        keyed,
+                        write_conf)
+    if connection_config:
+        conn_conf = as_java_object(rdd.ctx._gateway, ConnectionConf.build(**connection_config).settings())
+        fn_delete = partial(fn_delete, conn_conf)
+    fn_delete()
 
 
 class _CassandraRDD(RDD):
@@ -311,7 +317,7 @@ def __copy__(self):
 
 
 class CassandraTableScanRDD(_CassandraRDD):
-    def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None,
+    def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None, connection_config=None,
                  **read_conf_kwargs):
         super(CassandraTableScanRDD, self).__init__(ctx, keyspace, table,
                                                     row_format, read_conf,
@@ -320,9 +326,11 @@ def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None,
         self._key_by = ColumnSelector.none()
 
         read_conf = as_java_object(ctx._gateway, self.read_conf.settings())
-
-        self.crdd = self._helper \
-            .cassandraTable(ctx._jsc, keyspace, table, read_conf)
+        fn_read_table = partial(self._helper.cassandraTable, ctx._jsc, keyspace, table, read_conf)
+        if connection_config:
+            conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
+            fn_read_table = partial(fn_read_table, conn_conf)
+        self.crdd = fn_read_table()
 
     def by_primary_key(self):
         return self.key_by(primary_key=True)
@@ -411,7 +419,7 @@ def asDataFrames(self, *index_by):
         return rdd
 
 
-def joinWithCassandraTable(left_rdd, keyspace, table):
+def joinWithCassandraTable(left_rdd, keyspace, table, connection_config=None):
     """
         Join an RDD with a Cassandra table on the partition key. Use .on(...)
         to specifiy other columns to join on. .select(...), .where(...) and
@@ -425,9 +433,10 @@ def joinWithCassandraTable(left_rdd, keyspace, table):
             The keyspace to join on
         @param table(string):
             The CQL table to join on.
+        @param connection_config(ConnectionConf):
+            A ConnectionConf object to use when joining with a non-default Cassandra cluster
     """
-
-    return CassandraJoinRDD(left_rdd, keyspace, table)
+    return CassandraJoinRDD(left_rdd, keyspace, table, connection_config)
 
 
 class CassandraJoinRDD(_CassandraRDD):
@@ -435,10 +444,13 @@ class CassandraJoinRDD(_CassandraRDD):
         TODO
     """
 
-    def __init__(self, left_rdd, keyspace, table):
+    def __init__(self, left_rdd, keyspace, table, connection_config=None):
         super(CassandraJoinRDD, self).__init__(left_rdd.ctx, keyspace, table)
-        self.crdd = self._helper\
-            .joinWithCassandraTable(left_rdd._jrdd, keyspace, table)
+        fn_read_rdd = partial(self._helper.joinWithCassandraTable, left_rdd._jrdd, keyspace, table)
+        if connection_config:
+            conn_conf = as_java_object(left_rdd.ctx._gateway, ConnectionConf.build(**connection_config).settings())
+            fn_read_rdd = partial(fn_read_rdd, conn_conf)
+        self.crdd = fn_read_rdd()
 
     def on(self, *columns):
         columns = as_java_array(self.ctx._gateway, "String",
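The README examples above cover reading and saving against a second cluster; joins accept the same keyword. A hedged sketch, mirroring the tests added further down (the keyspace, table, column and host names are placeholders, and a `CassandraSparkContext` with the patched RDD methods is assumed):

```python
# Sketch only -- hypothetical names, not part of the patch.
import pyspark_cassandra  # noqa: F401  (patches joinWithCassandraTable onto RDDs)

left = sc.parallelize([{'key': 'a'}, {'key': 'b'}])

joined = left.joinWithCassandraTable(
    "keyspace", "table",
    connection_config={"spark_cassandra_connection_host": "cas-2"},
).on('key').select('key', 'value')

print(joined.collect())
```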
diff --git a/python/pyspark_cassandra/streaming.py b/python/pyspark_cassandra/streaming.py
index 65a86d6..3fbf3a9 100644
--- a/python/pyspark_cassandra/streaming.py
+++ b/python/pyspark_cassandra/streaming.py
@@ -9,18 +9,18 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
-
+from functools import partial
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
 from pyspark.streaming.dstream import DStream
 
-from pyspark_cassandra.conf import WriteConf
+from pyspark_cassandra.conf import WriteConf, ConnectionConf
 from pyspark_cassandra.util import as_java_object, as_java_array
 from pyspark_cassandra.util import helper
 
 
 def saveToCassandra(dstream, keyspace, table, columns=None, row_format=None,
                     keyed=None,
-                    write_conf=None, **write_conf_kwargs):
+                    write_conf=None, connection_config=None, **write_conf_kwargs):
     ctx = dstream._ssc._sc
     gw = ctx._gateway
 
@@ -30,14 +30,20 @@ def saveToCassandra(dstream, keyspace, table, columns=None, row_format=None,
     # convert the columns to a string array
     columns = as_java_array(gw, "String", columns) if columns else None
 
-    return helper(ctx).saveToCassandra(dstream._jdstream, keyspace, table,
-                                       columns, row_format,
-                                       keyed, write_conf)
+    fn_save = partial(helper(ctx).saveToCassandra, dstream._jdstream, keyspace, table,
+                      columns, row_format,
+                      keyed, write_conf)
+
+    if connection_config:
+        conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
+        fn_save = partial(fn_save, conn_conf)
+
+    return fn_save()
 
 
 def deleteFromCassandra(dstream, keyspace=None, table=None,
                         deleteColumns=None, keyColumns=None,
-                        row_format=None, keyed=None, write_conf=None,
+                        row_format=None, keyed=None, write_conf=None, connection_config=None,
                         **write_conf_kwargs):
     """Delete data from Cassandra table, using data from the RDD as primary
     keys. Uses the specified column names.
@@ -86,15 +92,19 @@ def deleteFromCassandra(dstream, keyspace=None, table=None, deleteColumns=None,
                                 deleteColumns) if deleteColumns else None
     keyColumns = as_java_array(gw, "String", keyColumns) \
         if keyColumns else None
+    fn_delete = partial(helper(ctx).deleteFromCassandra, dstream._jdstream, keyspace, table,
+                        deleteColumns, keyColumns,
+                        row_format,
+                        keyed, write_conf)
+    if connection_config:
+        conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
+        fn_delete = partial(fn_delete, conn_conf)
 
-    return helper(ctx).deleteFromCassandra(dstream._jdstream, keyspace, table,
-                                           deleteColumns, keyColumns,
-                                           row_format,
-                                           keyed, write_conf)
+    return fn_delete()
 
 
 def joinWithCassandraTable(dstream, keyspace, table, selected_columns=None,
-                           join_columns=None):
+                           join_columns=None, connection_config=None):
     """Joins a DStream (a stream of RDDs) with a Cassandra table
 
     Arguments:
@@ -121,9 +131,13 @@ def joinWithCassandraTable(dstream, keyspace, table, selected_columns=None,
                                join_columns) if join_columns else None
 
     h = helper(ctx)
-    dstream = h.joinWithCassandraTable(dstream._jdstream, keyspace, table,
-                                       selected_columns,
-                                       join_columns)
+    fn_read_join = partial(h.joinWithCassandraTable, dstream._jdstream, keyspace, table,
+                           selected_columns,
+                           join_columns)
+    if connection_config:
+        conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
+        fn_read_join = partial(fn_read_join, conn_conf)
+    dstream = fn_read_join()
     dstream = h.pickleRows(dstream)
     dstream = h.javaDStream(dstream)
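The DStream API mirrors the RDD API. A hedged sketch of saving a stream to a second cluster (the `ssc`/`sc` contexts, keyspace, table and host names are placeholders, not part of the patch):

```python
# Sketch only -- assumes pyspark_cassandra.streaming has been imported so that
# DStream.saveToCassandra is patched in, and that `ssc` and `sc` already exist.
import pyspark_cassandra.streaming  # noqa: F401

stream = ssc.queueStream([sc.parallelize([{'key': 'a', 'text': 'a'}])])

stream.saveToCassandra(
    "keyspace", "table",
    connection_config={"spark_cassandra_connection_host": "cas-2"},
)

ssc.start()
```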
diff --git a/python/tests.py b/python/tests.py
index 51aef19..9089ac0 100644
--- a/python/tests.py
+++ b/python/tests.py
@@ -37,7 +37,6 @@
 
 
 class CassandraTestCase(unittest.TestCase):
-
     keyspace = "test_pyspark_cassandra"
 
     def setUp(self):
@@ -74,6 +73,14 @@ def rdd(self, keyspace=None, table=None, key=None, column=None, **kwargs):
             rdd = rdd.select(column)
         return rdd
 
+    def general_read_test(self, type_name, value=None):
+        self.read_test(type_name, value)
+        self.read_test_cluster(type_name, value)
+
+    def general_read_write_test(self, type_name, value=None):
+        self.read_write_test(type_name, value)
+        self.read_write_test_cluster(type_name, value)
+
     def read_test(self, type_name, value=None):
         rdd = self.rdd(key=type_name, column=type_name)
         self.assertEqual(rdd.count(), 1)
@@ -87,6 +94,21 @@ def read_write_test(self, type_name, value):
         rdd.saveToCassandra(self.keyspace, self.table)
         return self.read_test(type_name, value)
 
+    def read_test_cluster(self, type_name, value=None):
+        rdd = self.rdd(key=type_name, column=type_name,
+                       connection_config={"spark_cassandra_connection_host": "localhost"})
+        self.assertEqual(rdd.count(), 1)
+        read = getattr(rdd.first(), type_name)
+        self.assertEqual(read, value)
+        return read
+
+    def read_write_test_cluster(self, type_name, value):
+        row = {'key': type_name, type_name: value}
+        rdd = self.sc.parallelize([row])
+        rdd.saveToCassandra(self.keyspace, self.table,
+                            connection_config={"spark_cassandra_connection_host": "localhost"})
+        return self.read_test_cluster(type_name, value)
+
 
 class SimpleTypesTestBase(CassandraTestCase):
     table = "simple_types"
@@ -113,28 +135,28 @@ def setUp(self):
 
 class SimpleTypesTest(SimpleTypesTestBase):
     def test_ascii(self):
-        self.read_write_test('ascii', 'some ascii')
+        self.general_read_write_test('ascii', 'some ascii')
 
     def test_bigint(self):
-        self.read_write_test('bigint', sys.maxsize)
+        self.general_read_write_test('bigint', sys.maxsize)
 
     def test_blob(self):
-        self.read_write_test('blob', bytearray('some blob'.encode('ascii')))
+        self.general_read_write_test('blob', bytearray('some blob'.encode('ascii')))
 
     def test_boolean(self):
-        self.read_write_test('boolean', False)
+        self.general_read_write_test('boolean', False)
 
     def test_date(self):
-        self.read_write_test('date', date(2018, 8, 1))
+        self.general_read_write_test('date', date(2018, 8, 1))
 
     def test_decimal(self):
-        self.read_write_test('decimal', Decimal(0.5))
+        self.general_read_write_test('decimal', Decimal(0.5))
 
     def test_double(self):
-        self.read_write_test('double', 0.5)
+        self.general_read_write_test('double', 0.5)
 
     def test_float(self):
-        self.read_write_test('float', 0.5)
+        self.general_read_write_test('float', 0.5)
 
     # TODO returns resolved hostname with ip address (hostname/ip,
     # e.g. /127.0.0.1), but doesn't accept with / ...
@@ -142,10 +164,10 @@ def test_float(self):
     # self.read_write_test('inet', u'/127.0.0.1')
 
     def test_int(self):
-        self.read_write_test('int', 1)
+        self.general_read_write_test('int', 1)
 
     def test_text(self):
-        self.read_write_test('text', u'some text')
+        self.general_read_write_test('text', u'some text')
 
     # TODO implement test with datetime with tzinfo without depending on pytz
     # def test_timestamp(self):
@@ -153,13 +175,13 @@ def test_text(self):
 
     def test_timeuuid(self):
         uuid = uuid_from_time(datetime(2015, 1, 1))
-        self.read_write_test('timeuuid', uuid)
+        self.general_read_write_test('timeuuid', uuid)
 
     def test_varchar(self):
-        self.read_write_test('varchar', u'some varchar')
+        self.general_read_write_test('varchar', u'some varchar')
 
     def test_varint(self):
-        self.read_write_test('varint', 1)
+        self.general_read_write_test('varint', 1)
 
     def test_uuid(self):
         self.read_write_test('uuid',
@@ -581,7 +603,6 @@ def test_write_conf(self):
 
 
 class StreamingTest(SimpleTypesTestBase):
-
     interval = .1
 
     @classmethod
@@ -692,6 +713,89 @@ def test_composite_pk(self):
     # .limit()
 
 
+class JoinClusterRDDTest(SimpleTypesTestBase):
+    def setUp(self):
+        super(JoinClusterRDDTest, self).setUp()
+
+    def test_simple_pk(self):
+        table = 'join_rdd_test_simple_pk'
+
+        self.session.execute('''
+            CREATE TABLE IF NOT EXISTS ''' + table + ''' (
+                key text primary key, value text
+            )
+        ''')
+        self.session.execute('TRUNCATE %s' % table)
+
+        rows = {
+            str(c): str(i) for i, c in
+            enumerate(string.ascii_lowercase)
+        }
+
+        for k, v in rows.items():
+            self.session.execute(
+                'INSERT INTO ' + table +
+                ' (key, value) values (%s, %s)', (k, v)
+            )
+
+        rdd = self.sc.parallelize(rows.items())
+        self.assertEqual(dict(rdd.collect()), rows)
+
+        tbl = rdd.joinWithCassandraTable(
+            self.keyspace, table,
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+        joined = tbl.on('key').select('key', 'value').cache()
+        self.assertEqual(dict(joined.keys().collect()),
+                         dict(joined.values().collect()))
+        for (k, v) in joined.collect():
+            self.assertEqual(k, v)
+
+    def test_composite_pk(self):
+        table = 'join_rdd_test_composite_pk'
+
+        self.session.execute('''
+            CREATE TABLE IF NOT EXISTS ''' + table + ''' (
+                pk text, cc text, value text,
+                primary key (pk, cc)
+            )
+        ''')
+        self.session.execute('TRUNCATE %s' % table)
+
+        rows = [
+            # (pk, cc, pk + '-' + cc)
+            (pk, cc, pk + '-' + cc)
+            for pk in string.ascii_lowercase[:3]
+            for cc in (str(i) for i in range(3))
+        ]
+
+        for row in rows:
+            self.session.execute(
+                'INSERT INTO ' + table +
+                ' (pk, cc, value) values (%s, %s, %s)',
+                row
+            )
+
+        rdd = self.sc.parallelize(rows)
+
+        joined = rdd.joinWithCassandraTable(
+            self.keyspace, table,
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        ).on('pk', 'cc')
+        self.assertEqual(sorted(zip(rows, rows)),
+                         sorted(joined.map(tuple).collect()))
+
+        joined = rdd.joinWithCassandraTable(
+            self.keyspace, table,
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        ).on('pk')
+        self.assertEqual(len(rows) * sqrt(len(rows)), joined.count())
+
+        # TODO test
+        # .where()
+        # .limit()
+
+
 class JoinDStreamTest(StreamingTest):
     def setUp(self):
         super(JoinDStreamTest, self).setUp()
@@ -729,6 +833,44 @@ def test(self):
             self.assertEqual(left['text'], right['text'])
             self.assertEqual(len(right), 1)
 
 
+class JoinClusterDStreamTest(StreamingTest):
+    def setUp(self):
+        super(JoinClusterDStreamTest, self).setUp()
+        self.joined_rows = self.sc.accumulator(
+            [], accum_param=AddingAccumulatorParam([]))
+
+    def checkRDD(self, time, rdd):
+        self.joined_rows += rdd.collect()
+
+    def test(self):
+        rows = list(chain(*self.rows))
+        rows_by_key = {row['key']: row for row in rows}
+
+        self.sc \
+            .parallelize(rows) \
+            .saveToCassandra(self.keyspace, self.table)
+
+        self.stream \
+            .joinWithCassandraTable(
+                self.keyspace, self.table, ['text'], ['key'],
+                connection_config={"spark_cassandra_connection_host": "localhost"}
+            ).foreachRDD(self.checkRDD)
+
+        self.ssc.start()
+        self.ssc.awaitTermination((self.count + 1) * self.interval)
+        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
+
+        joined_rows = self.joined_rows.value
+        self.assertEqual(len(joined_rows), len(rows))
+        for row in joined_rows:
+            self.assertEqual(len(row), 2)
+            left, right = row
+
+            self.assertEqual(type(left), type(right))
+            self.assertEqual(rows_by_key[left['key']], left)
+            self.assertEqual(left['text'], right['text'])
+            self.assertEqual(len(right), 1)
+
+
 class DeleteFromCassandraStreamingTest(SimpleTypesTestBase):
     size = 10
@@ -811,6 +953,95 @@ def test_delete_all_rows_explicit(self):
         data = self.rdd()
         self.assertEqual(len(data.collect()), 0)
 
 
+class DeleteFromCassandraClusterStreamingTest(SimpleTypesTestBase):
+    size = 10
+    interval = .1
+
+    def setUp(self):
+        super(DeleteFromCassandraClusterStreamingTest, self).setUp()
+        self.ssc = StreamingContext(self.sc, self.interval)
+
+        self.rdds = [self.sc.parallelize(range(0, self.size)).map(
+            lambda i: {'key': i, 'int': i, 'text': i})]
+        data = self.rdds[0]
+        data.saveToCassandra(self.keyspace, self.table)
+
+        # verify the RDD length and actual content
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify we actually have data for `text` and `int`
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertEqual(row.text, u'0')
+        self.assertEqual(row.int, 0)
+
+        # stream we will use in the tests below
+        self.stream = self.ssc.queueStream(self.rdds)
+
+    def test_delete_single_column(self):
+        self.stream.deleteFromCassandra(
+            self.keyspace, self.table,
+            deleteColumns=['text'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        self.ssc.start()
+        self.ssc.awaitTermination((self.size + 1) * self.interval)
+        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
+
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify `text` got deleted while `int` is still present
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertEqual(row.int, 0)
+        self.assertIsNone(row.text)
+
+    def test_delete_2_columns(self):
+        self.stream.deleteFromCassandra(
+            self.keyspace, self.table,
+            deleteColumns=['text', 'int'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        self.ssc.start()
+        self.ssc.awaitTermination((self.size + 1) * self.interval)
+        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
+
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify both the `text` and `int` columns got deleted
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertIsNone(row.int)
+        self.assertIsNone(row.text)
+
+    def test_delete_all_rows_default(self):
+        self.stream.deleteFromCassandra(
+            self.keyspace, self.table,
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        self.ssc.start()
+        self.ssc.awaitTermination((self.size + 1) * self.interval)
+        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
+
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), 0)
+
+    def test_delete_all_rows_explicit(self):
+        self.stream.deleteFromCassandra(
+            self.keyspace, self.table, keyColumns=['key'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        self.ssc.start()
+        self.ssc.awaitTermination((self.size + 1) * self.interval)
+        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
+
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), 0)
+
+
 class DeleteFromCassandraTest(SimpleTypesTestBase):
     size = 1000
@@ -933,6 +1164,142 @@ def test_delete_all_rows_explicit(self):
         data = self.rdd()
         self.assertEqual(len(data.collect()), 0)
 
 
+class DeleteFromCassandraClusterTest(SimpleTypesTestBase):
+    size = 1000
+
+    def setUp(self):
+        super(DeleteFromCassandraClusterTest, self).setUp()
+        data = self.sc.parallelize(range(0, self.size)).map(
+            lambda i: {'key': i, 'int': i, 'text': i})
+        data.saveToCassandra(self.keyspace, self.table)
+
+    def test_delete_selected_cols_seq(self):
+        data = self.rdd()
+
+        # verify the RDD length.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify we actually have data for `text` and `int`
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertEqual(row.text, u'0')
+        self.assertEqual(row.int, 0)
+
+        # delete content in the `text` column only.
+        data.deleteFromCassandra(
+            self.keyspace, self.table,
+            deleteColumns=['text'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        # verify the RDD length did not change.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify the `text` column got deleted.
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertIsNone(row.text)
+        self.assertEqual(row.int, 0)
+
+        # delete content in the `int` column.
+        data.deleteFromCassandra(
+            self.keyspace, self.table,
+            deleteColumns=['int'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        # verify the RDD length did not change.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify the `int` column got deleted.
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertIsNone(row.text)
+        self.assertIsNone(row.int)
+
+        # reload the RDD and check the columns are still deleted.
+        data = self.rdd()
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertIsNone(row.text)
+        self.assertIsNone(row.int)
+
+    def test_delete_selected_cols(self):
+        data = self.rdd()
+
+        # verify the RDD length.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify we actually have data for `text` and `int`
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertEqual(row.text, u'0')
+        self.assertEqual(row.int, 0)
+
+        # delete content in the `text` and `int` columns.
+        data.deleteFromCassandra(
+            self.keyspace, self.table,
+            deleteColumns=['text', 'int'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        # verify the RDD length did not change.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify the `text` and `int` columns got deleted.
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertIsNone(row.text)
+        self.assertIsNone(row.int)
+
+        # reload the RDD and check the columns are still deleted.
+        data = self.rdd()
+        row = data.select('text', 'int').where('key=?', '0').first()
+        self.assertIsNone(row.text)
+        self.assertIsNone(row.int)
+
+    def test_delete_all_rows_default(self):
+        data = self.rdd()
+
+        # verify the RDD length.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify we actually have data for `key`, `text` and `int`
+        row = data.select('key', 'text', 'int').where('key=?', '0').first()
+        self.assertEqual(row.key, u'0')
+        self.assertEqual(row.text, u'0')
+        self.assertEqual(row.int, 0)
+
+        # delete all content
+        data.deleteFromCassandra(
+            self.keyspace, self.table,
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        # verify the RDD length.
+        self.assertEqual(len(data.collect()), 0)
+
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), 0)
+
+    def test_delete_all_rows_explicit(self):
+        data = self.rdd()
+
+        # verify the RDD length.
+        self.assertEqual(len(data.collect()), self.size)
+
+        # verify we actually have data for `key`, `text` and `int`
+        row = data.select('key', 'text', 'int').where('key=?', '0').first()
+        self.assertEqual(row.key, u'0')
+        self.assertEqual(row.text, u'0')
+        self.assertEqual(row.int, 0)
+
+        # delete all content
+        data.deleteFromCassandra(
+            self.keyspace, self.table, keyColumns=['key'],
+            connection_config={"spark_cassandra_connection_host": "localhost"}
+        )
+
+        # verify the RDD length.
+        self.assertEqual(len(data.collect()), 0)
+
+        data = self.rdd()
+        self.assertEqual(len(data.collect()), 0)
+
+
 class RegressionTest(CassandraTestCase):
     def test_64(self):
diff --git a/src/main/scala/pyspark_cassandra/PythonHelper.scala b/src/main/scala/pyspark_cassandra/PythonHelper.scala
index 79b8ae0..c9cac14 100644
--- a/src/main/scala/pyspark_cassandra/PythonHelper.scala
+++ b/src/main/scala/pyspark_cassandra/PythonHelper.scala
@@ -18,6 +18,7 @@ import java.lang.Boolean
 import java.util.{Map => JMap}
 
 import com.datastax.spark.connector._
+import com.datastax.spark.connector.cql.CassandraConnector
 import com.datastax.spark.connector.rdd._
 import com.datastax.spark.connector.streaming.toDStreamFunctions
 import com.datastax.spark.connector.types.TypeConverter
@@ -44,6 +45,13 @@ class PythonHelper() extends Serializable {
     jsc.sc.cassandraTable(keyspace, table).withReadConf(conf)
   }
 
+  def cassandraTable(jsc: JavaSparkContext, keyspace: String, table: String, readConf: JMap[String, Any], conConf: JMap[String, String]) = {
+    val conf = parseReadConf(jsc.sc, Some(readConf))
+    implicit val rrf = new DeferringRowReaderFactory()
+    implicit val connector = CassandraConnector(jsc.sc.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    jsc.sc.cassandraTable(keyspace, table).withReadConf(conf)
+  }
+
   def select(rdd: CassandraRDD[UnreadRow], columns: Array[String]) = {
     rdd.select(columns.map {
       new ColumnName(_)
@@ -106,6 +114,45 @@ class PythonHelper() extends Serializable {
     dstream.dstream.unpickle().saveToCassandra(keyspace, table, selectedColumns, conf)
   }
 
+  /* custom cluster -------------------------------------------------------- */
+  /* rdds ------------------------------------------------------------------ */
+  def saveToCassandra(rdd: JavaRDD[Array[Byte]], keyspace: String, table: String, columns: JMap[String, String],
+                      rowFormat: Integer, keyed: Boolean, writeConf: JMap[String, Any], conConf: JMap[String, String]) = {
+
+    val selectedColumns = columnSelector(columns)
+    val conf = parseWriteConf(Some(writeConf))
+
+    implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed))
+    implicit val connector = CassandraConnector(rdd.sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    rdd.rdd.unpickle().saveToCassandra(keyspace, table, selectedColumns, conf)
+  }
+
+  def saveToCassandra(rdd: JavaRDD[Array[Byte]], keyspace: String, table: String, columns: Array[String],
+                      rowFormat: Integer, keyed: Boolean, writeConf: JMap[String, Any], conConf: JMap[String, String]) = {
+
+    val selectedColumns = columnSelector(columns)
+    val conf = parseWriteConf(Some(writeConf))
+
+    implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed))
+    implicit val connector = CassandraConnector(rdd.sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    rdd.rdd.unpickle().saveToCassandra(keyspace, table, selectedColumns, conf)
+  }
+
+  /* dstreams -------------------------------------------------------------- */
+  def saveToCassandra(dstream: JavaDStream[Array[Byte]], keyspace: String, table: String, columns: Array[String],
+                      rowFormat: Integer, keyed: Boolean, writeConf: JMap[String, Any], conConf: JMap[String, String]) = {
+
+    val selectedColumns = columnSelector(columns)
+    val conf = parseWriteConf(Some(writeConf))
+
+    implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed))
+    implicit val connector = CassandraConnector(dstream.context().sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    dstream.dstream.unpickle().saveToCassandra(keyspace, table, selectedColumns, conf)
+  }
+
   /* ----------------------------------------------------------------------- */
   /* join with cassandra tables -------------------------------------------- */
   /* ----------------------------------------------------------------------- */
@@ -118,6 +165,15 @@ class PythonHelper() extends Serializable {
     rdd.rdd.unpickle().joinWithCassandraTable(keyspace, table)
   }
 
+  def joinWithCassandraTable(rdd: JavaRDD[Array[Byte]], keyspace: String, table: String, conConf: JMap[String, String]): CassandraJoinRDD[Any, UnreadRow] = {
+    implicit val rwf = new GenericRowWriterFactory(None, None)
+    implicit val rrf = new DeferringRowReaderFactory()
+    implicit val connector = CassandraConnector(rdd.sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    rdd.rdd.unpickle().joinWithCassandraTable(keyspace, table)
+  }
+
   def on(rdd: CassandraJoinRDD[Any, UnreadRow], columns: Array[String]) = {
     rdd.on(columnSelector(columns, PartitionKeyColumns))
   }
@@ -133,6 +189,17 @@ class PythonHelper() extends Serializable {
     dstream.dstream.unpickle().joinWithCassandraTable(keyspace, table, columns, joinOn)
   }
 
+  def joinWithCassandraTable(dstream: JavaDStream[Array[Byte]], keyspace: String, table: String,
+                             selectedColumns: Array[String], joinColumns: Array[String], conConf: JMap[String, String]): DStream[(Any, UnreadRow)] = {
+    val columns = columnSelector(selectedColumns)
+    val joinOn = columnSelector(joinColumns, PartitionKeyColumns)
+    implicit val rwf = new GenericRowWriterFactory(None, None)
+    implicit val rrf = new DeferringRowReaderFactory()
+    implicit val connector = CassandraConnector(dstream.context().sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    dstream.dstream.unpickle().joinWithCassandraTable(keyspace, table, columns, joinOn)
+  }
+
   /* ----------------------------------------------------------------------- */
   /* delete from cassandra --------------------------------------------------*/
   /* ----------------------------------------------------------------------- */
@@ -150,6 +217,20 @@ class PythonHelper() extends Serializable {
     rdd.rdd.unpickle().deleteFromCassandra(keyspace, table, deletes, keys, conf)
   }
 
+  def deleteFromCassandra(rdd: JavaRDD[Array[Byte]], keyspace: String, table: String,
+                          deleteColumns: Array[String], keyColumns: Array[String],
+                          rowFormat: Integer, keyed: Boolean,
+                          writeConf: JMap[String, Any],
+                          conConf: JMap[String, String]) = {
+    val deletes = columnSelector(deleteColumns, SomeColumns())
+    val keys = columnSelector(keyColumns, PrimaryKeyColumns)
+    val conf = parseWriteConf(Some(writeConf))
+    implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed))
+    implicit val connector = CassandraConnector(rdd.sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    rdd.rdd.unpickle().deleteFromCassandra(keyspace, table, deletes, keys, conf)
+  }
+
   /* dstreams ------------------------------------------------------------------ */
 
   def deleteFromCassandra(dstream: JavaDStream[Array[Byte]], keyspace: String, table: String,
@@ -163,6 +244,20 @@ class PythonHelper() extends Serializable {
     dstream.dstream.unpickle().deleteFromCassandra(keyspace, table, deletes, keys, conf)
   }
 
+  def deleteFromCassandra(dstream: JavaDStream[Array[Byte]], keyspace: String, table: String,
+                          deleteColumns: Array[String], keyColumns: Array[String],
+                          rowFormat: Integer, keyed: Boolean,
+                          writeConf: JMap[String, Any],
+                          conConf: JMap[String, String]) = {
+    val deletes = columnSelector(deleteColumns, SomeColumns())
+    val keys = columnSelector(keyColumns, PrimaryKeyColumns)
+    val conf = parseWriteConf(Some(writeConf))
+    implicit val rwf = new GenericRowWriterFactory(Format(rowFormat), asBooleanOption(keyed))
+    implicit val connector = CassandraConnector(dstream.context().sparkContext.getConf.set("spark.cassandra.connection.host",
+      conConf.get("spark_cassandra_connection_host")))
+    dstream.dstream.unpickle().deleteFromCassandra(keyspace, table, deletes, keys, conf)
+  }
+
   /* ----------------------------------------------------------------------- */
   /* utilities for moving rdds and dstreams from and to pyspark ------------ */
   /* ----------------------------------------------------------------------- */
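On the JVM side each new overload simply builds a `CassandraConnector` from a copy of the `SparkConf` with `spark.cassandra.connection.host` overridden by `conConf.get("spark_cassandra_connection_host")`; nothing else changes for the caller. A hedged end-to-end sketch of what this enables (host, keyspace and table names are placeholders, and a `CassandraSparkContext` `sc` is assumed):

```python
# Sketch only -- copy a table from the default cluster to a second cluster.
import pyspark_cassandra  # noqa: F401

rows = sc.cassandraTable("keyspace", "table")  # default spark.cassandra.connection.host
rows.saveToCassandra(
    "keyspace", "table",
    connection_config={"spark_cassandra_connection_host": "cas-2"},
)
```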