Multi cluster access (#50)
Add connection configuration into cassandraTable

Co-authored-by: Ivan Salamakha <[email protected]>
VanyaDNDZ and Ivan Salamakha authored Aug 3, 2022
1 parent af88190 commit 40ca6e5
Showing 6 changed files with 574 additions and 65 deletions.
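For orientation, a minimal sketch of the multi-cluster workflow this change enables; the host names, keyspace, and table names are illustrative, `sc` is an existing SparkContext, and importing `pyspark_cassandra` is assumed to attach `cassandraTable`/`saveToCassandra` to the SparkContext and RDD:

```python
import pyspark_cassandra  # noqa: F401 -- assumed to patch SparkContext/RDD on import

# Read rows from one cluster ...
rows = sc.cassandraTable(
    "keyspace", "source_table",
    connection_config={"spark_cassandra_connection_host": "cas-1"},
)

# ... and write them to a table on a second, non-default cluster.
rows.saveToCassandra(
    "keyspace", "target_table",
    connection_config={"spark_cassandra_connection_host": "cas-2"},
)
```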
18 changes: 18 additions & 0 deletions README.md
@@ -203,6 +203,7 @@ for more details. *But* it exposes one additional method:
* `split_size` sets the size in the number of CQL rows in each partition (defaults to `100000`)
* `fetch_size` sets the number of rows to fetch per request from Cassandra (defaults to `1000`)
* `consistency_level` sets with which consistency level to read the data (defaults to `LOCAL_ONE`)
* `connection_config` sets a map with connection information used to connect to a non-default cluster (defaults to `None`)


### pyspark.RDD
@@ -311,6 +312,15 @@ sc \
.collect()
```

Reading from different clusters::
```python
rdd_one = sc \
    .cassandraTable("keyspace", "table_one", connection_config={"spark_cassandra_connection_host": "cas-1"})

rdd_two = sc \
    .cassandraTable("keyspace", "table_two", connection_config={"spark_cassandra_connection_host": "cas-2"})
```

Storing data in Cassandra::

```python
@@ -330,6 +340,14 @@ rdd.saveToCassandra(
"table",
ttl=timedelta(hours=1),
)

# Storing to non-default cluster
rdd.saveToCassandra(
"keyspace",
"table",
ttl=timedelta(hours=1),
connection_config={"spark_cassandra_connection_host": "cas-2"}
)
```

Modify CQL collections::
3 changes: 3 additions & 0 deletions python/pyspark_cassandra/conf.py
@@ -34,6 +34,9 @@ def __str__(self):
', '.join('%s=%s' % (k, v) for k, v in self.settings().items())
)

class ConnectionConf(_Conf):
    def __init__(self, spark_cassandra_connection_host):
        self.spark_cassandra_connection_host = spark_cassandra_connection_host

class ReadConf(_Conf):
def __init__(self, split_count=None, split_size=None, fetch_size=None,
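For context, a small sketch of how the new `ConnectionConf` is presumably consumed by the RDD functions below; `build()` and `settings()` are assumed to be inherited from the shared `_Conf` base class, as they are for `ReadConf` and `WriteConf`:

```python
from pyspark_cassandra.conf import ConnectionConf

# The user-facing kwarg is a plain dict ...
connection_config = {"spark_cassandra_connection_host": "cas-2"}

# ... which is assumed to be turned into a ConnectionConf via the _Conf.build()
# helper, then flattened back into a settings map that as_java_object() can
# hand to the JVM-side helper.
conn_conf = ConnectionConf.build(**connection_config)
settings = conn_conf.settings()  # expected to contain the connection host entry
```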
82 changes: 47 additions & 35 deletions python/pyspark_cassandra/rdd.py
Expand Up @@ -11,13 +11,14 @@
# limitations under the License.

import sys
from functools import partial
from copy import copy
from itertools import groupby
from operator import itemgetter

from pyspark.rdd import RDD

from .conf import ReadConf, WriteConf
from .conf import ReadConf, WriteConf, ConnectionConf
from .format import ColumnSelector, RowFormat
from .types import Row
from .util import as_java_array, as_java_object, helper
Expand All @@ -33,7 +34,7 @@

def saveToCassandra(rdd, keyspace=None, table=None, columns=None,
row_format=None, keyed=None,
write_conf=None, **write_conf_kwargs):
write_conf=None, connection_config=None, **write_conf_kwargs):
"""
Saves an RDD to Cassandra. The RDD is expected to contain dicts with
keys mapping to CQL columns.
@@ -63,6 +64,8 @@ def saveToCassandra(rdd, keyspace=None, table=None, columns=None,
@param write_conf(WriteConf):
A WriteConf object to use when saving to Cassandra
@param connection_config(ConnectionConf):
A ConnectionConf object to use when saving to a non-default Cassandra cluster
@param **write_conf_kwargs:
WriteConf parameters to use when saving to Cassandra
"""
@@ -88,21 +91,22 @@ def saveToCassandra(rdd, keyspace=None, table=None, columns=None,
columns = as_java_array(rdd.ctx._gateway, "String",
columns) if columns else None

helper(rdd.ctx) \
.saveToCassandra(
rdd._jrdd,
keyspace,
table,
columns,
row_format,
keyed,
write_conf,
)
    save = partial(helper(rdd.ctx).saveToCassandra, rdd._jrdd,
                   keyspace,
                   table,
                   columns,
                   row_format,
                   keyed,
                   write_conf)
    if connection_config:
        conn_conf = as_java_object(rdd.ctx._gateway, ConnectionConf.build(**connection_config).settings())
        save = partial(save, conn_conf)
    save()


def deleteFromCassandra(rdd, keyspace=None, table=None, deleteColumns=None,
keyColumns=None, row_format=None, keyed=None,
write_conf=None, **write_conf_kwargs):
write_conf=None, connection_config=None, **write_conf_kwargs):
"""
Delete data from Cassandra table, using data from the RDD as primary
keys. Uses the specified column names.
@@ -136,6 +140,8 @@ def deleteFromCassandra(rdd, keyspace=None, table=None, deleteColumns=None,
@param write_conf(WriteConf):
A WriteConf object to use when saving to Cassandra
@param connection_config(ConnectionConf):
A ConnectionConf object to use when saving to a non-default Cassandra cluster
@param **write_conf_kwargs:
WriteConf parameters to use when saving to Cassandra
"""
@@ -157,18 +163,18 @@ def deleteFromCassandra(rdd, keyspace=None, table=None, deleteColumns=None,
if deleteColumns else None
keyColumns = as_java_array(rdd.ctx._gateway, "String", keyColumns) \
if keyColumns else None

helper(rdd.ctx) \
.deleteFromCassandra(
rdd._jrdd,
keyspace,
table,
deleteColumns,
keyColumns,
row_format,
keyed,
write_conf,
)
    fn_delete = partial(helper(rdd.ctx).deleteFromCassandra, rdd._jrdd,
                        keyspace,
                        table,
                        deleteColumns,
                        keyColumns,
                        row_format,
                        keyed,
                        write_conf)
    if connection_config:
        conn_conf = as_java_object(rdd.ctx._gateway, ConnectionConf.build(**connection_config).settings())
        fn_delete = partial(fn_delete, conn_conf)
    fn_delete()


class _CassandraRDD(RDD):
@@ -311,7 +317,7 @@ def __copy__(self):


class CassandraTableScanRDD(_CassandraRDD):
def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None,
def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None, connection_config=None,
**read_conf_kwargs):
super(CassandraTableScanRDD, self).__init__(ctx, keyspace, table,
row_format, read_conf,
@@ -320,9 +326,11 @@ def __init__(self, ctx, keyspace, table, row_format=None, read_conf=None,
self._key_by = ColumnSelector.none()

read_conf = as_java_object(ctx._gateway, self.read_conf.settings())

self.crdd = self._helper \
.cassandraTable(ctx._jsc, keyspace, table, read_conf)
        fn_read_table = partial(self._helper.cassandraTable, ctx._jsc, keyspace, table, read_conf)
        if connection_config:
            conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
            fn_read_table = partial(fn_read_table, conn_conf)
        self.crdd = fn_read_table()

def by_primary_key(self):
return self.key_by(primary_key=True)
@@ -411,7 +419,7 @@ def asDataFrames(self, *index_by):
return rdd


def joinWithCassandraTable(left_rdd, keyspace, table):
def joinWithCassandraTable(left_rdd, keyspace, table, connection_config=None):
"""
Join an RDD with a Cassandra table on the partition key. Use .on(...)
to specify other columns to join on. .select(...), .where(...) and
@@ -425,20 +433,24 @@ def joinWithCassandraTable(left_rdd, keyspace, table):
The keyspace to join on
@param table(string):
The CQL table to join on.
@param connection_config(ConnectionConf):
A ConnectionConf object to use when joining with a non-default Cassandra cluster
"""

return CassandraJoinRDD(left_rdd, keyspace, table)
return CassandraJoinRDD(left_rdd, keyspace, table, connection_config)


class CassandraJoinRDD(_CassandraRDD):
"""
TODO
"""

def __init__(self, left_rdd, keyspace, table):
def __init__(self, left_rdd, keyspace, table, connection_config=None):
super(CassandraJoinRDD, self).__init__(left_rdd.ctx, keyspace, table)
self.crdd = self._helper\
.joinWithCassandraTable(left_rdd._jrdd, keyspace, table)
        fn_read_rdd = partial(self._helper.joinWithCassandraTable, left_rdd._jrdd, keyspace, table)
        if connection_config:
            conn_conf = as_java_object(left_rdd.ctx._gateway, ConnectionConf.build(**connection_config).settings())
            fn_read_rdd = partial(fn_read_rdd, conn_conf)
        self.crdd = fn_read_rdd()

def on(self, *columns):
columns = as_java_array(self.ctx._gateway, "String",
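As a usage illustration for the join path above, a hedged sketch; the `id` key column and host name are illustrative, `sc` is an existing SparkContext, and `joinWithCassandraTable` is assumed to be patched onto RDD by importing `pyspark_cassandra`:

```python
import pyspark_cassandra  # noqa: F401 -- assumed to patch RDD.joinWithCassandraTable

# Join a local RDD of partition keys against a table on a non-default cluster.
keys = sc.parallelize([{"id": 1}, {"id": 2}])
joined = keys \
    .joinWithCassandraTable(
        "keyspace", "table",
        connection_config={"spark_cassandra_connection_host": "cas-2"},
    ) \
    .on("id")
print(joined.collect())
```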
44 changes: 29 additions & 15 deletions python/pyspark_cassandra/streaming.py
@@ -9,18 +9,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
from pyspark.streaming.dstream import DStream

from pyspark_cassandra.conf import WriteConf
from pyspark_cassandra.conf import WriteConf, ConnectionConf
from pyspark_cassandra.util import as_java_object, as_java_array
from pyspark_cassandra.util import helper


def saveToCassandra(dstream, keyspace, table, columns=None, row_format=None,
keyed=None,
write_conf=None, **write_conf_kwargs):
write_conf=None, connection_config=None, **write_conf_kwargs):
ctx = dstream._ssc._sc
gw = ctx._gateway

@@ -30,14 +30,20 @@ def saveToCassandra(dstream, keyspace, table, columns=None, row_format=None,
# convert the columns to a string array
columns = as_java_array(gw, "String", columns) if columns else None

return helper(ctx).saveToCassandra(dstream._jdstream, keyspace, table,
columns, row_format,
keyed, write_conf)
    fn_save = partial(helper(ctx).saveToCassandra, dstream._jdstream, keyspace, table,
                      columns, row_format,
                      keyed, write_conf)

    if connection_config:
        conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
        fn_save = partial(fn_save, conn_conf)

    return fn_save()


def deleteFromCassandra(dstream, keyspace=None, table=None, deleteColumns=None,
keyColumns=None,
row_format=None, keyed=None, write_conf=None,
row_format=None, keyed=None, write_conf=None, connection_config=None,
**write_conf_kwargs):
"""Delete data from Cassandra table, using data from the RDD as primary
keys. Uses the specified column names.
@@ -86,15 +92,19 @@ def deleteFromCassandra(dstream, keyspace=None, table=None, deleteColumns=None,
deleteColumns) if deleteColumns else None
keyColumns = as_java_array(gw, "String", keyColumns) \
if keyColumns else None
    fn_delete = partial(helper(ctx).deleteFromCassandra, dstream._jdstream, keyspace, table,
                        deleteColumns, keyColumns,
                        row_format,
                        keyed, write_conf)
    if connection_config:
        conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
        fn_delete = partial(fn_delete, conn_conf)

return helper(ctx).deleteFromCassandra(dstream._jdstream, keyspace, table,
deleteColumns, keyColumns,
row_format,
keyed, write_conf)
    return fn_delete()


def joinWithCassandraTable(dstream, keyspace, table, selected_columns=None,
join_columns=None):
join_columns=None, connection_config=None):
"""Joins a DStream (a stream of RDDs) with a Cassandra table
Arguments:
@@ -121,9 +131,13 @@ def joinWithCassandraTable(dstream, keyspace, table, selected_columns=None,
join_columns) if join_columns else None

h = helper(ctx)
dstream = h.joinWithCassandraTable(dstream._jdstream, keyspace, table,
selected_columns,
join_columns)
    fn_read_join = partial(h.joinWithCassandraTable, dstream._jdstream, keyspace, table,
                           selected_columns,
                           join_columns)
    if connection_config:
        conn_conf = as_java_object(ctx._gateway, ConnectionConf.build(**connection_config).settings())
        fn_read_join = partial(fn_read_join, conn_conf)
    dstream = fn_read_join()
dstream = h.pickleRows(dstream)
dstream = h.javaDStream(dstream)

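And, for the streaming path, a hedged sketch of writing a DStream to a non-default cluster; the queue contents, batch interval, and host name are illustrative, `sc` is an existing SparkContext, and importing `pyspark_cassandra` is assumed to attach `saveToCassandra` to DStream:

```python
from pyspark.streaming import StreamingContext

import pyspark_cassandra  # noqa: F401 -- assumed to patch DStream.saveToCassandra

ssc = StreamingContext(sc, batchDuration=5)

# Each micro-batch of dicts is written to a table on the "cas-2" cluster.
stream = ssc.queueStream([sc.parallelize([{"key": 1, "value": "a"}])])
stream.saveToCassandra(
    "keyspace", "table",
    connection_config={"spark_cassandra_connection_host": "cas-2"},
)

ssc.start()
ssc.awaitTermination()
```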