Skip to content

Commit

Permalink
Merge pull request #3 from narma/zio_refactoring_and_finish_null_values
Browse files Browse the repository at this point in the history
Zio refactoring and finish null values
  • Loading branch information
narma authored Mar 17, 2022
2 parents ccb26c5 + 7b3ef20 commit a6ac648
Show file tree
Hide file tree
Showing 11 changed files with 206 additions and 134 deletions.
8 changes: 5 additions & 3 deletions src/it/resources/migration/1__test_tables.cql
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
create table tests.test_data(
id bigint,
data text,
count int,
dataset frozen<set<int>>,
PRIMARY KEY (id)
);

insert into tests.test_data (id, data) values (0, null);
insert into tests.test_data (id, data) values (1, 'one');
insert into tests.test_data (id, data) values (2, 'two');
insert into tests.test_data (id, data, count, dataset) values (0, null, null, null);
insert into tests.test_data (id, data, count, dataset) values (1, 'one', 10, {});
insert into tests.test_data (id, data, count, dataset) values (2, 'two', 20, {201});
insert into tests.test_data (id, data) values (3, 'three');

create table tests.test_data_multiple_keys(
Expand Down
181 changes: 93 additions & 88 deletions src/it/scala/zio/cassandra/session/cql/CqlSpec.scala

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions src/main/scala/zio/cassandra/session/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ object Session {
override def keyspace: Option[CqlIdentifier] = underlying.getKeyspace.toScala
}

def live: ZManaged[Has[CqlSessionBuilder], Throwable, Session] =
ZManaged.serviceWithManaged[CqlSessionBuilder] { cqlSessionBuilder =>
make(cqlSessionBuilder)
}

def make(builder: => CqlSessionBuilder): TaskManaged[Session] =
ZManaged
.make(Task.fromCompletionStage(builder.buildAsync())) { session =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ trait CassandraTypeMapper[Scala] {
def classType: Class[Cassandra]
def toCassandra(in: Scala, dataType: DataType): Cassandra
def fromCassandra(in: Cassandra, dataType: DataType): Scala
def allowNullable: Boolean = false
}

object CassandraTypeMapper {
Expand Down Expand Up @@ -246,5 +247,7 @@ object CassandraTypeMapper {

override def fromCassandra(in: Cassandra, dataType: DataType): Option[A] =
Option(in).map(ev.fromCassandra(_, dataType))

override def allowNullable: Boolean = true
}
}
19 changes: 17 additions & 2 deletions src/main/scala/zio/cassandra/session/cql/FromUdtValue.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package zio.cassandra.session.cql

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.data.{ GettableByIndex, UdtValue }
import zio.cassandra.session.cql.FromUdtValue.{ make, makeWithFieldName }

Expand All @@ -23,7 +24,16 @@ object FromUdtValue extends LowerPriorityFromUdtValue with LowestPriorityFromUdt

def deriveReads[A](implicit ev: FromUdtValue.Object[A]): Reads[A] = (row: GettableByIndex, index: Int) => {
val udtValue = row.getUdtValue(index)
ev.convert(FieldName.Unused, udtValue)
try ev.convert(FieldName.Unused, udtValue)
catch {
case UnexpectedNullValueInUdt.NullValueInUdt(udtValue, fieldName) =>
throw new UnexpectedNullValueInUdt(
row.asInstanceOf[Row],
index,
udtValue,
fieldName
) // FIXME: get rid of .asInstanceOf
}
}

// only allowed to summon fully built out FromUdtValue instances which are built by Shapeless machinery
Expand Down Expand Up @@ -71,7 +81,12 @@ trait LowerPriorityFromUdtValue {
ev: CassandraTypeMapper[A]
): FromUdtValue[A] =
makeWithFieldName[A] { (fieldName, udtValue) =>
ev.fromCassandra(udtValue.get(fieldName, ev.classType), udtValue.getType(fieldName))
if (udtValue.isNull(fieldName)) {
if (ev.allowNullable)
None.asInstanceOf[A]
else throw UnexpectedNullValueInUdt.NullValueInUdt(udtValue, fieldName)
} else
ev.fromCassandra(udtValue.get(fieldName, ev.classType), udtValue.getType(fieldName))
}
}

Expand Down
20 changes: 2 additions & 18 deletions src/main/scala/zio/cassandra/session/cql/Reads.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ trait Reads[T] { self =>

def read(row: GettableByIndex, index: Int): Task[T] =
if (row.isNull(index)) {
Task.fail(new UnexpectedNullValue(row, index))
Task.fail(UnexpectedNullValueInColumn(row.asInstanceOf[Row], index))
} else {
Task(readUnsafe(row, index))
}
Expand All @@ -28,22 +28,6 @@ trait Reads[T] { self =>
(row: GettableByIndex, index: Int) => f(self.readUnsafe(row, index))
}

class UnexpectedNullValue(row: GettableByIndex, index: Int) extends RuntimeException() {
override def getMessage: String =
row match {
case row: Row =>
val cl = row.getColumnDefinitions.get(index)
val table = cl.getTable.toString
val column = cl.getName.toString
val keyspace = cl.getKeyspace.toString
val tpe = cl.getType.asCql(true, true)

s"Read NULL value from column $column with type $tpe at $keyspace.$table. Row ${row.getFormattedContents}"
case _ =>
s"Read NULL value from column at index $index"
}
}

object Reads extends ReadsLowerPriority with ReadsLowestPriority {
def apply[T](implicit r: Reads[T]): Reads[T] = r

Expand Down Expand Up @@ -148,7 +132,7 @@ trait ReadsLowestPriority {

override def read(row: GettableByIndex, index: Int): Task[A] =
if (row.isNull(index)) {
Task.fail(new UnexpectedNullValue(row, index))
Task.fail(UnexpectedNullValueInColumn(row.asInstanceOf[Row], index))
} else {
val tpe = row.getType(index)
if (tpe.isInstanceOf[UserDefinedType]) {
Expand Down
42 changes: 42 additions & 0 deletions src/main/scala/zio/cassandra/session/cql/UnexpectedNullValue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package zio.cassandra.session.cql

import com.datastax.oss.driver.api.core.cql.Row
import com.datastax.oss.driver.api.core.data.UdtValue

import scala.util.control.NoStackTrace

sealed trait UnexpectedNullValue extends Throwable

case class UnexpectedNullValueInColumn(row: Row, index: Int) extends RuntimeException() with UnexpectedNullValue {
override def getMessage: String = {
val cl = row.getColumnDefinitions.get(index)
val table = cl.getTable.toString
val column = cl.getName.toString
val keyspace = cl.getKeyspace.toString
val tpe = cl.getType.asCql(true, true)

s"Read NULL value from $keyspace.$table column $column expected $tpe. Row ${row.getFormattedContents}"
}
}

case class UnexpectedNullValueInUdt(row: Row, index: Int, udt: UdtValue, fieldName: String)
extends RuntimeException()
with UnexpectedNullValue {
override def getMessage: String = {
val cl = row.getColumnDefinitions.get(index)
val table = cl.getTable.toString
val column = cl.getName.toString
val keyspace = cl.getKeyspace.toString
val tpe = cl.getType.asCql(true, true)

val udtTpe = udt.getType(fieldName)

s"Read NULL value from $keyspace.$table inside UDT column $column with type $tpe. NULL value in $fieldName, expected type $udtTpe. Row ${row.getFormattedContents}"
}

}

object UnexpectedNullValueInUdt {
private[cql] case class NullValueInUdt(udtValue: UdtValue, fieldName: String) extends NoStackTrace
}

14 changes: 9 additions & 5 deletions src/main/scala/zio/cassandra/session/cql/query/Batch.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package zio.cassandra.session.cql.query

import com.datastax.oss.driver.api.core.cql.{ BatchStatementBuilder, BatchType }
import zio.Task
import com.datastax.oss.driver.api.core.cql.{BatchStatementBuilder, BatchType}
import zio.cassandra.session.Session
import zio.{Has, RIO, ZIO}

class Batch(batchStatementBuilder: BatchStatementBuilder) {
def add(queries: Seq[Query[_]]) = new Batch(batchStatementBuilder.addStatements(queries.map(_.statement): _*))
def execute(session: Session): Task[Boolean] =
session.execute(batchStatementBuilder.build()).map(_.wasApplied)
def add(queries: Seq[Query[_]]) = new Batch(batchStatementBuilder.addStatements(queries.map(_.statement): _*))

def execute: RIO[Has[Session], Boolean] =
ZIO.accessM[Has[Session]] { session =>
session.get.execute(batchStatementBuilder.build()).map(_.wasApplied)
}

def config(config: BatchStatementBuilder => BatchStatementBuilder): Batch =
new Batch(config(batchStatementBuilder))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@ package zio.cassandra.session.cql.query
import com.datastax.oss.driver.api.core.cql.BoundStatement
import shapeless.HList
import shapeless.ops.hlist.Prepend
import zio.Task
import zio.cassandra.session.cql.{ Binder, Reads }
import zio.cassandra.session.Session
import zio.stream.Stream
import zio.cassandra.session.cql.{Binder, Reads}
import zio.stream.{Stream, ZStream}
import zio.{Has, RIO}

case class ParameterizedQuery[V <: HList: Binder, R: Reads] private (template: QueryTemplate[V, R], values: V) {
def +(that: String): ParameterizedQuery[V, R] = ParameterizedQuery[V, R](this.template + that, this.values)
def as[R1: Reads]: ParameterizedQuery[V, R1] = ParameterizedQuery[V, R1](template.as[R1], values)
def select(session: Session): Stream[Throwable, R] =
Stream.unwrap(template.prepare(session).map(_.applyProduct(values).select))

def selectFirst(session: Session): Task[Option[R]] =
template.prepare(session).flatMap(_.applyProduct(values).selectFirst)
def execute(session: Session): Task[Boolean] =
template.prepare(session).map(_.applyProduct(values)).flatMap(_.execute)
def +(that: String): ParameterizedQuery[V, R] = ParameterizedQuery[V, R](this.template + that, this.values)
def as[R1: Reads]: ParameterizedQuery[V, R1] = ParameterizedQuery[V, R1](template.as[R1], values)

def select: ZStream[Has[Session], Throwable, R] =
Stream.unwrap(template.prepare.map(_.applyProduct(values).select))

def selectFirst: RIO[Has[Session], Option[R]] =
template.prepare.flatMap(_.applyProduct(values).selectFirst)

def execute: RIO[Has[Session], Boolean] =
template.prepare.map(_.applyProduct(values)).flatMap(_.execute)

def config(config: BoundStatement => BoundStatement): ParameterizedQuery[V, R] =
ParameterizedQuery[V, R](template.config(config), values)

def stripMargin: ParameterizedQuery[V, R] = ParameterizedQuery[V, R](this.template.stripMargin, values)

def ++[W <: HList, Out <: HList](that: ParameterizedQuery[W, R])(implicit
Expand Down
3 changes: 3 additions & 0 deletions src/main/scala/zio/cassandra/session/cql/query/Query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ class Query[R: Reads] private[cql] (
private[cql] val statement: BoundStatement
) {
def config(statement: BoundStatement => BoundStatement) = new Query[R](session, statement(this.statement))

def select: Stream[Throwable, R] = session.select(statement).mapChunksM { chunk =>
chunk.mapM(Reads[R].read(_, 0))
}

def selectFirst: Task[Option[R]] = session.selectFirst(statement).flatMap {
case None => ZIO.none
case Some(row) =>
Reads[R].read(row, 0).map(Some(_))
}

def execute: Task[Boolean] = session.execute(statement).map(_.wasApplied)
}
17 changes: 11 additions & 6 deletions src/main/scala/zio/cassandra/session/cql/query/QueryTemplate.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,27 @@ package zio.cassandra.session.cql.query
import com.datastax.oss.driver.api.core.cql.BoundStatement
import shapeless.HList
import shapeless.ops.hlist.Prepend
import zio.Task
import zio.cassandra.session.cql.{ Binder, Reads }
import zio.cassandra.session.Session
import zio.cassandra.session.cql.{Binder, Reads}
import zio.{Has, RIO, ZIO}

import scala.annotation.nowarn

case class QueryTemplate[V <: HList: Binder, R: Reads] private[cql] (
query: String,
config: BoundStatement => BoundStatement
) {
def +(that: String): QueryTemplate[V, R] = QueryTemplate[V, R](this.query + that, config)
def as[R1: Reads]: QueryTemplate[V, R1] = QueryTemplate[V, R1](query, config)
def prepare(session: Session): Task[PreparedQuery[V, R]] =
session.prepare(query).map(new PreparedQuery(session, _, config))
def +(that: String): QueryTemplate[V, R] = QueryTemplate[V, R](this.query + that, config)
def as[R1: Reads]: QueryTemplate[V, R1] = QueryTemplate[V, R1](query, config)

def prepare: RIO[Has[Session], PreparedQuery[V, R]] =
ZIO.accessM[Has[Session]] { session =>
session.get.prepare(query).map(new PreparedQuery(session.get, _, config))
}

def config(config: BoundStatement => BoundStatement): QueryTemplate[V, R] =
QueryTemplate[V, R](this.query, this.config andThen config)

def stripMargin: QueryTemplate[V, R] = QueryTemplate[V, R](this.query.stripMargin, this.config)

def ++[W <: HList, Out <: HList](that: QueryTemplate[W, R])(implicit
Expand Down

0 comments on commit a6ac648

Please sign in to comment.