From 811b08b4988b904345a158172f1c0c845e7bfd7f Mon Sep 17 00:00:00 2001 From: Sergey Rublev Date: Thu, 17 Mar 2022 16:06:33 +0300 Subject: [PATCH 1/3] Handle null values in UDT --- src/it/resources/migration/1__test_tables.cql | 8 ++- .../zio/cassandra/session/cql/CqlSpec.scala | 71 ++++++++++++++----- .../session/cql/CassandraTypeMapper.scala | 3 + .../cassandra/session/cql/FromUdtValue.scala | 19 ++++- .../zio/cassandra/session/cql/Reads.scala | 20 +----- .../session/cql/UnexpectedNullValue.scala | 42 +++++++++++ 6 files changed, 122 insertions(+), 41 deletions(-) create mode 100644 src/main/scala/zio/cassandra/session/cql/UnexpectedNullValue.scala diff --git a/src/it/resources/migration/1__test_tables.cql b/src/it/resources/migration/1__test_tables.cql index 856d173..97021ff 100644 --- a/src/it/resources/migration/1__test_tables.cql +++ b/src/it/resources/migration/1__test_tables.cql @@ -1,12 +1,14 @@ create table tests.test_data( id bigint, data text, + count int, + dataset frozen>, PRIMARY KEY (id) ); -insert into tests.test_data (id, data) values (0, null); -insert into tests.test_data (id, data) values (1, 'one'); -insert into tests.test_data (id, data) values (2, 'two'); +insert into tests.test_data (id, data, count, dataset) values (0, null, null, null); +insert into tests.test_data (id, data, count, dataset) values (1, 'one', 10, {}); +insert into tests.test_data (id, data, count, dataset) values (2, 'two', 20, {201}); insert into tests.test_data (id, data) values (3, 'three'); create table tests.test_data_multiple_keys( diff --git a/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala b/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala index cb46107..c7faa42 100644 --- a/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala +++ b/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala @@ -4,12 +4,12 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel import zio.cassandra.session.Session import zio.duration._ import zio.stream.Stream -import zio.test.Assertion.{isLeft, isSubtype} +import zio.test.Assertion.{ isLeft, isSubtype } import zio.test.TestAspect.ignore import zio.test._ -import zio.{Chunk, Task, ZIO} +import zio.{ Chunk, Task, ZIO } -import java.time.{LocalDate, LocalTime} +import java.time.{ LocalDate, LocalTime } import java.util.UUID import java.util.concurrent.atomic.AtomicInteger @@ -321,18 +321,30 @@ object CqlSpec { } yield assertTrue(result.isDefined && result.get == data) }, suite("handle NULL values")( - testM("return None if a type is Option") { + testM("return None for Option[String") { for { session <- ZIO.service[Session] result <- cql"select data FROM tests.test_data WHERE id = 0".as[Option[String]].selectFirst(session) } yield assertTrue(result.isDefined && result.get.isEmpty) }, - testM("raise error if a type is not an Option") { + testM("raise error for String(nin-primitive)") { for { session <- ZIO.service[Session] result <- cql"select data FROM tests.test_data WHERE id = 0".as[String].selectFirst(session).either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, + testM("raise error for Int(primitive)") { + for { + session <- ZIO.service[Session] + result <- cql"select count FROM tests.test_data WHERE id = 0".as[Int].selectFirst(session).either + } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) + }, + testM("raise error for Set(collection)") { + for { + session <- ZIO.service[Session] + result <- cql"select dataset FROM tests.test_data WHERE id = 0".as[Set[Int]].selectFirst(session).either + } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) + }, testM("return value for field in case class have Option type") { for { session <- ZIO.service[Session] @@ -372,10 +384,10 @@ object CqlSpec { .either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, - testM("return None when inner value is null for optional type") { + testM("return None when udt field value is null for optional type") { val data = PersonOptAttribute( personAttributeIdxCounter.incrementAndGet(), - OptBasicInfo(Some(160.0), None, Some(Set(1))) + OptBasicInfo(None, None, None) ) for { @@ -387,37 +399,60 @@ object CqlSpec { .selectFirst(session) } yield assertTrue(result.contains(data)) }, - testM("return None when inner set is null") { + testM("raise error if udt field value is mapped to String(non-primitive)") { val data = - PersonOptAttribute(personAttributeIdxCounter.incrementAndGet(), OptBasicInfo(Some(160.0), None, None)) + PersonOptAttribute( + personAttributeIdxCounter.incrementAndGet(), + OptBasicInfo(Some(160.0), None, Some(Set(1))) + ) for { session <- ZIO.service[Session] _ <- - cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, {weight:160.0,height:NULL,datapoints:NULL})" + cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" .execute(session) result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" - .as[PersonOptAttribute] + .as[PersonAttribute] .selectFirst(session) - } yield assertTrue(result.contains(data)) - } @@ ignore, // todo: why datapoints which is frozen> returns as Some(Set()) here? - testM("raise error when inner non-optional value is null for non-optional type") { + .either + } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) + }, + testM("raise error if udt field value is mapped to Double(primitive)") { val data = PersonOptAttribute( personAttributeIdxCounter.incrementAndGet(), - OptBasicInfo(Some(160.0), None, Some(Set(1))) + OptBasicInfo(None, Some("tall"), Some(Set(1))) ) for { session <- ZIO.service[Session] - _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + _ <- + cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" + .execute(session) + result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" + .as[PersonAttribute] + .selectFirst(session) + .either + } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) + }, + testM("raise error if udt field value is mapped to Double(primitive)") { + val data = + PersonOptAttribute( + personAttributeIdxCounter.incrementAndGet(), + OptBasicInfo(Some(160.0), Some("tall"), None) + ) + + for { + session <- ZIO.service[Session] + _ <- + cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" + .execute(session) result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonAttribute] .selectFirst(session) .either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) - } @@ ignore // fixme + } ) ) ) diff --git a/src/main/scala/zio/cassandra/session/cql/CassandraTypeMapper.scala b/src/main/scala/zio/cassandra/session/cql/CassandraTypeMapper.scala index b5efe53..15a9081 100644 --- a/src/main/scala/zio/cassandra/session/cql/CassandraTypeMapper.scala +++ b/src/main/scala/zio/cassandra/session/cql/CassandraTypeMapper.scala @@ -21,6 +21,7 @@ trait CassandraTypeMapper[Scala] { def classType: Class[Cassandra] def toCassandra(in: Scala, dataType: DataType): Cassandra def fromCassandra(in: Cassandra, dataType: DataType): Scala + def allowNullable: Boolean = false } object CassandraTypeMapper { @@ -246,5 +247,7 @@ object CassandraTypeMapper { override def fromCassandra(in: Cassandra, dataType: DataType): Option[A] = Option(in).map(ev.fromCassandra(_, dataType)) + + override def allowNullable: Boolean = true } } diff --git a/src/main/scala/zio/cassandra/session/cql/FromUdtValue.scala b/src/main/scala/zio/cassandra/session/cql/FromUdtValue.scala index 12a8dfc..82a2082 100644 --- a/src/main/scala/zio/cassandra/session/cql/FromUdtValue.scala +++ b/src/main/scala/zio/cassandra/session/cql/FromUdtValue.scala @@ -1,5 +1,6 @@ package zio.cassandra.session.cql +import com.datastax.oss.driver.api.core.cql.Row import com.datastax.oss.driver.api.core.data.{ GettableByIndex, UdtValue } import zio.cassandra.session.cql.FromUdtValue.{ make, makeWithFieldName } @@ -23,7 +24,16 @@ object FromUdtValue extends LowerPriorityFromUdtValue with LowestPriorityFromUdt def deriveReads[A](implicit ev: FromUdtValue.Object[A]): Reads[A] = (row: GettableByIndex, index: Int) => { val udtValue = row.getUdtValue(index) - ev.convert(FieldName.Unused, udtValue) + try ev.convert(FieldName.Unused, udtValue) + catch { + case UnexpectedNullValueInUdt.NullValueInUdt(udtValue, fieldName) => + throw new UnexpectedNullValueInUdt( + row.asInstanceOf[Row], + index, + udtValue, + fieldName + ) // FIXME: get rid of .asInstanceOf + } } // only allowed to summon fully built out FromUdtValue instances which are built by Shapeless machinery @@ -71,7 +81,12 @@ trait LowerPriorityFromUdtValue { ev: CassandraTypeMapper[A] ): FromUdtValue[A] = makeWithFieldName[A] { (fieldName, udtValue) => - ev.fromCassandra(udtValue.get(fieldName, ev.classType), udtValue.getType(fieldName)) + if (udtValue.isNull(fieldName)) { + if (ev.allowNullable) + None.asInstanceOf[A] + else throw UnexpectedNullValueInUdt.NullValueInUdt(udtValue, fieldName) + } else + ev.fromCassandra(udtValue.get(fieldName, ev.classType), udtValue.getType(fieldName)) } } diff --git a/src/main/scala/zio/cassandra/session/cql/Reads.scala b/src/main/scala/zio/cassandra/session/cql/Reads.scala index 056d447..5a48c00 100644 --- a/src/main/scala/zio/cassandra/session/cql/Reads.scala +++ b/src/main/scala/zio/cassandra/session/cql/Reads.scala @@ -17,7 +17,7 @@ trait Reads[T] { self => def read(row: GettableByIndex, index: Int): Task[T] = if (row.isNull(index)) { - Task.fail(new UnexpectedNullValue(row, index)) + Task.fail(UnexpectedNullValueInColumn(row.asInstanceOf[Row], index)) } else { Task(readUnsafe(row, index)) } @@ -28,22 +28,6 @@ trait Reads[T] { self => (row: GettableByIndex, index: Int) => f(self.readUnsafe(row, index)) } -class UnexpectedNullValue(row: GettableByIndex, index: Int) extends RuntimeException() { - override def getMessage: String = - row match { - case row: Row => - val cl = row.getColumnDefinitions.get(index) - val table = cl.getTable.toString - val column = cl.getName.toString - val keyspace = cl.getKeyspace.toString - val tpe = cl.getType.asCql(true, true) - - s"Read NULL value from column $column with type $tpe at $keyspace.$table. Row ${row.getFormattedContents}" - case _ => - s"Read NULL value from column at index $index" - } -} - object Reads extends ReadsLowerPriority with ReadsLowestPriority { def apply[T](implicit r: Reads[T]): Reads[T] = r @@ -148,7 +132,7 @@ trait ReadsLowestPriority { override def read(row: GettableByIndex, index: Int): Task[A] = if (row.isNull(index)) { - Task.fail(new UnexpectedNullValue(row, index)) + Task.fail(UnexpectedNullValueInColumn(row.asInstanceOf[Row], index)) } else { val tpe = row.getType(index) if (tpe.isInstanceOf[UserDefinedType]) { diff --git a/src/main/scala/zio/cassandra/session/cql/UnexpectedNullValue.scala b/src/main/scala/zio/cassandra/session/cql/UnexpectedNullValue.scala new file mode 100644 index 0000000..afda58e --- /dev/null +++ b/src/main/scala/zio/cassandra/session/cql/UnexpectedNullValue.scala @@ -0,0 +1,42 @@ +package zio.cassandra.session.cql + +import com.datastax.oss.driver.api.core.cql.Row +import com.datastax.oss.driver.api.core.data.UdtValue + +import scala.util.control.NoStackTrace + +sealed trait UnexpectedNullValue extends Throwable + +case class UnexpectedNullValueInColumn(row: Row, index: Int) extends RuntimeException() with UnexpectedNullValue { + override def getMessage: String = { + val cl = row.getColumnDefinitions.get(index) + val table = cl.getTable.toString + val column = cl.getName.toString + val keyspace = cl.getKeyspace.toString + val tpe = cl.getType.asCql(true, true) + + s"Read NULL value from $keyspace.$table column $column expected $tpe. Row ${row.getFormattedContents}" + } +} + +case class UnexpectedNullValueInUdt(row: Row, index: Int, udt: UdtValue, fieldName: String) + extends RuntimeException() + with UnexpectedNullValue { + override def getMessage: String = { + val cl = row.getColumnDefinitions.get(index) + val table = cl.getTable.toString + val column = cl.getName.toString + val keyspace = cl.getKeyspace.toString + val tpe = cl.getType.asCql(true, true) + + val udtTpe = udt.getType(fieldName) + + s"Read NULL value from $keyspace.$table inside UDT column $column with type $tpe. NULL value in $fieldName, expected type $udtTpe. Row ${row.getFormattedContents}" + } + +} + +object UnexpectedNullValueInUdt { + private[cql] case class NullValueInUdt(udtValue: UdtValue, fieldName: String) extends NoStackTrace +} + From 0486d2c6209ef2699d06dea536b179576e2b7e29 Mon Sep 17 00:00:00 2001 From: Sergey Rublev Date: Thu, 17 Mar 2022 23:27:33 +0300 Subject: [PATCH 2/3] Improve ergonomics, move session to Environment type of ZIO --- .../zio/cassandra/session/cql/CqlSpec.scala | 136 +++++++----------- .../cassandra/session/cql/query/Batch.scala | 14 +- .../cql/query/ParameterizedQuery.scala | 28 ++-- .../cassandra/session/cql/query/Query.scala | 3 + .../session/cql/query/QueryTemplate.scala | 17 ++- 5 files changed, 92 insertions(+), 106 deletions(-) diff --git a/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala b/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala index c7faa42..1f6315a 100644 --- a/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala +++ b/src/it/scala/zio/cassandra/session/cql/CqlSpec.scala @@ -4,12 +4,12 @@ import com.datastax.oss.driver.api.core.ConsistencyLevel import zio.cassandra.session.Session import zio.duration._ import zio.stream.Stream -import zio.test.Assertion.{ isLeft, isSubtype } +import zio.test.Assertion.{isLeft, isSubtype} import zio.test.TestAspect.ignore import zio.test._ -import zio.{ Chunk, Task, ZIO } +import zio.{Chunk, Has, RIO, ZIO} -import java.time.{ LocalDate, LocalTime } +import java.time.{LocalDate, LocalTime} import java.util.UUID import java.util.concurrent.atomic.AtomicInteger @@ -77,22 +77,20 @@ object CqlSpec { val cqlSuite = suite("cql suite")( testM("interpolated select template should return data from migration") { for { - session <- ZIO.service[Session] prepared <- cqlt"select data FROM tests.test_data WHERE id in ${Put[List[Long]]}" .as[String] .config(_.setTimeout(1.seconds)) - .prepare(session) + .prepare query = prepared(List[Long](1, 2, 3)) results <- query.select.runCollect } yield assertTrue(results == Chunk("one", "two", "three")) }, testM("interpolated select template should return tuples from migration with multiple binding") { for { - session <- ZIO.service[Session] query <- cqlt"select data FROM tests.test_data_multiple_keys WHERE id1 = ${Put[Long]} and id2 = ${Put[Int]}" .as[String] - .prepare(session) + .prepare results <- query(1L, 2).config(_.setExecutionProfileName("default")).select.runCollect } yield assertTrue(results == Chunk("one-two")) }, @@ -100,25 +98,22 @@ object CqlSpec { "interpolated select template should return tuples from migration with multiple binding and margin stripped" ) { for { - session <- ZIO.service[Session] query <- cqlt"""select data FROM tests.test_data_multiple_keys - |WHERE id1 = ${Put[Long]} and id2 = ${Put[Int]}""".stripMargin.as[String].prepare(session) + |WHERE id1 = ${Put[Long]} and id2 = ${Put[Int]}""".stripMargin.as[String].prepare results <- query(1L, 2).config(_.setExecutionProfileName("default")).select.runCollect } yield assertTrue(results == Chunk("one-two")) }, testM("interpolated select template should return data case class from migration") { for { - session <- ZIO.service[Session] prepared <- - cqlt"select id, data FROM tests.test_data WHERE id in ${Put[List[Long]]}".as[Data].prepare(session) + cqlt"select id, data FROM tests.test_data WHERE id in ${Put[List[Long]]}".as[Data].prepare query = prepared(List[Long](1, 2, 3)) results <- query.select.runCollect } yield assertTrue(results == Chunk(Data(1, "one"), Data(2, "two"), Data(3, "three"))) }, testM("interpolated select template should be reusable") { for { - session <- ZIO.service[Session] - query <- cqlt"select data FROM tests.test_data WHERE id = ${Put[Long]}".as[String].prepare(session) + query <- cqlt"select data FROM tests.test_data WHERE id = ${Put[Long]}".as[String].prepare result <- Stream.fromIterable(Seq(1L, 2L, 3L)).flatMap(i => query(i).select).runCollect } yield assertTrue(result == Chunk("one", "two", "three")) }, @@ -128,24 +123,21 @@ object CqlSpec { .as[String] .config(_.setConsistencyLevel(ConsistencyLevel.ALL)) for { - session <- ZIO.service[Session] - results <- getDataByIds(List(1, 2, 3)).select(session).runCollect + results <- getDataByIds(List(1, 2, 3)).select.runCollect } yield assertTrue(results == Chunk("one", "two", "three")) }, testM("interpolated select should return tuples from migration") { def getAllByIds(ids: List[Long]) = cql"select id, data FROM tests.test_data WHERE id in $ids".as[(Long, String)] for { - session <- ZIO.service[Session] - results <- getAllByIds(List(1, 2, 3)).config(_.setQueryTimestamp(0L)).select(session).runCollect + results <- getAllByIds(List(1, 2, 3)).config(_.setQueryTimestamp(0L)).select.runCollect } yield assertTrue(results == Chunk((1L, "one"), (2L, "two"), (3L, "three"))) }, testM("interpolated select should return tuples from migration with multiple binding") { def getAllByIds(id1: Long, id2: Int) = cql"select data FROM tests.test_data_multiple_keys WHERE id1 = $id1 and id2 = $id2".as[String] for { - session <- ZIO.service[Session] - results <- getAllByIds(1, 2).select(session).runCollect + results <- getAllByIds(1, 2).select.runCollect } yield assertTrue(results == Chunk("one-two")) }, testM("interpolated select should return tuples from migration with multiple binding and margin stripped") { @@ -153,16 +145,14 @@ object CqlSpec { cql"""select data FROM tests.test_data_multiple_keys |WHERE id1 = $id1 and id2 = $id2""".stripMargin.as[String] for { - session <- ZIO.service[Session] - results <- getAllByIds(1, 2).select(session).runCollect + results <- getAllByIds(1, 2).select.runCollect } yield assertTrue(results == Chunk("one-two")) }, testM("interpolated select should return data case class from migration") { def getIds(ids: List[Long]) = cql"select id, data FROM tests.test_data WHERE id in $ids".as[Data] for { - session <- ZIO.service[Session] - results <- getIds(List(1, 2, 3)).select(session).runCollect + results <- getIds(List(1, 2, 3)).select.runCollect } yield assertTrue(results == Chunk(Data(1, "one"), Data(2, "two"), Data(3, "three"))) }, testM( @@ -172,12 +162,11 @@ object CqlSpec { PersonAttribute(personAttributeIdxCounter.incrementAndGet(), BasicInfo(180.0, "tall", Set(1, 2, 3, 4, 5))) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonAttribute] - .select(session) + .select .runCollect } yield assertTrue(result.length == 1 && result.head == data) }, @@ -185,23 +174,22 @@ object CqlSpec { val dataRow1 = CollectionTestRow(1, Map("2" -> UUID.randomUUID()), Set(1, 2, 3), Option(List(LocalDate.now()))) val dataRow2 = CollectionTestRow(2, Map("3" -> UUID.randomUUID()), Set(4, 5, 6), None) - def insert(session: Session)(data: CollectionTestRow): Task[Boolean] = + def insert(data: CollectionTestRow): RIO[Has[Session], Boolean] = cql"INSERT INTO tests.test_collection (id, maptest, settest, listtest) VALUES (${data.id}, ${data.maptest}, ${data.settest}, ${data.listtest})" - .execute(session) + .execute - def retrieve(session: Session, id: Int, ids: Int*): Task[Chunk[CollectionTestRow]] = { + def retrieve(id: Int, ids: Int*): RIO[Has[Session], Chunk[CollectionTestRow]] = { val allIds = id :: ids.toList cql"SELECT id, maptest, settest, listtest FROM tests.test_collection WHERE id IN $allIds" .as[CollectionTestRow] - .select(session) + .select .runCollect } for { - session <- ZIO.service[Session] - _ <- ZIO.foreachPar_(List(dataRow1, dataRow2))(insert(session)) - res1 <- retrieve(session, dataRow1.id) - res2 <- retrieve(session, dataRow2.id) + _ <- ZIO.foreachPar_(List(dataRow1, dataRow2))(insert) + res1 <- retrieve(dataRow1.id) + res2 <- retrieve(dataRow2.id) } yield assertTrue(res1.length == 1 && res1.head == dataRow1) && assertTrue( res2.length == 1 && res2.head == dataRow2 ) @@ -256,11 +244,10 @@ object CqlSpec { ) for { - session <- ZIO.service[Session] - _ <- cql"INSERT INTO tests.heavily_nested_udt_table (id, data) VALUES (${row.id}, ${row.data})".execute(session) + _ <- cql"INSERT INTO tests.heavily_nested_udt_table (id, data) VALUES (${row.id}, ${row.data})".execute actual <- cql"SELECT id, data FROM tests.heavily_nested_udt_table WHERE id = ${row.id}" .as[TableContainingExampleCollectionNestedUdtType] - .select(session) + .select .runCollect } yield assertTrue(actual.length == 1 && actual.head == row) }, @@ -275,27 +262,23 @@ object CqlSpec { ) ) ) - def insert(session: Session) = - cql"INSERT INTO tests.heavily_nested_prim_table (id, data) VALUES (${row.id}, ${row.data})".execute( - session - ) + def insert = + cql"INSERT INTO tests.heavily_nested_prim_table (id, data) VALUES (${row.id}, ${row.data})".execute - def retrieve(session: Session) = cql"SELECT id, data FROM tests.heavily_nested_prim_table WHERE id = ${row.id}" + def retrieve = cql"SELECT id, data FROM tests.heavily_nested_prim_table WHERE id = ${row.id}" .as[TableContainingExampleNestedPrimitiveType] - .select(session) + .select .runCollect for { - session <- ZIO.service[Session] - _ <- insert(session) - actual <- retrieve(session) + _ <- insert + actual <- retrieve } yield assertTrue(actual.length == 1 && actual.head == row) }, testM("interpolated select should bind constants") { val query = cql"select data FROM tests.test_data WHERE id = ${1L}".as[String] for { - session <- ZIO.service[Session] - result <- query.select(session).runCollect + result <- query.select.runCollect } yield assertTrue(result == Chunk("one")) }, testM("cqlConst allows you to interpolate on what is usually not possible with cql strings") { @@ -310,51 +293,44 @@ object CqlSpec { def where(personId: Int) = cql" WHERE person_id = $personId" - def insert(session: Session, data: PersonAttribute) = + def insert(data: PersonAttribute) = (cql"INSERT INTO " ++ keyspace ++ table ++ cql" (person_id, info) VALUES (${data.personId}, ${data.info})") - .execute(session) + .execute for { - session <- ZIO.service[Session] - _ <- insert(session, data) - result <- (selectFrom ++ keyspace ++ table ++ where(data.personId)).as[PersonAttribute].selectFirst(session) + _ <- insert(data) + result <- (selectFrom ++ keyspace ++ table ++ where(data.personId)).as[PersonAttribute].selectFirst } yield assertTrue(result.isDefined && result.get == data) }, suite("handle NULL values")( testM("return None for Option[String") { for { - session <- ZIO.service[Session] - result <- cql"select data FROM tests.test_data WHERE id = 0".as[Option[String]].selectFirst(session) + result <- cql"select data FROM tests.test_data WHERE id = 0".as[Option[String]].selectFirst } yield assertTrue(result.isDefined && result.get.isEmpty) }, testM("raise error for String(nin-primitive)") { for { - session <- ZIO.service[Session] - result <- cql"select data FROM tests.test_data WHERE id = 0".as[String].selectFirst(session).either + result <- cql"select data FROM tests.test_data WHERE id = 0".as[String].selectFirst.either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, testM("raise error for Int(primitive)") { for { - session <- ZIO.service[Session] - result <- cql"select count FROM tests.test_data WHERE id = 0".as[Int].selectFirst(session).either + result <- cql"select count FROM tests.test_data WHERE id = 0".as[Int].selectFirst.either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, testM("raise error for Set(collection)") { for { - session <- ZIO.service[Session] - result <- cql"select dataset FROM tests.test_data WHERE id = 0".as[Set[Int]].selectFirst(session).either + result <- cql"select dataset FROM tests.test_data WHERE id = 0".as[Set[Int]].selectFirst.either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, testM("return value for field in case class have Option type") { for { - session <- ZIO.service[Session] - row <- cql"select id, data FROM tests.test_data WHERE id = 0".as[OptData].selectFirst(session) + row <- cql"select id, data FROM tests.test_data WHERE id = 0".as[OptData].selectFirst } yield assertTrue(row.isDefined && row.get.data.isEmpty) }, testM("raise error if field in case class have Option type") { for { - session <- ZIO.service[Session] - result <- cql"select id, data FROM tests.test_data WHERE id = 0".as[Data].selectFirst(session).either + result <- cql"select id, data FROM tests.test_data WHERE id = 0".as[Data].selectFirst.either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, suite("handle NULL values with udt")( @@ -362,12 +338,11 @@ object CqlSpec { val data = OptPersonAttribute(personAttributeIdxCounter.incrementAndGet(), None) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[OptPersonAttribute] - .select(session) + .select .runCollect } yield assertTrue(result.length == 1 && result.head == data) }, @@ -375,12 +350,11 @@ object CqlSpec { val data = OptPersonAttribute(personAttributeIdxCounter.incrementAndGet(), None) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonAttribute] - .selectFirst(session) + .selectFirst .either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, @@ -391,12 +365,11 @@ object CqlSpec { ) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonOptAttribute] - .selectFirst(session) + .selectFirst } yield assertTrue(result.contains(data)) }, testM("raise error if udt field value is mapped to String(non-primitive)") { @@ -407,13 +380,12 @@ object CqlSpec { ) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonAttribute] - .selectFirst(session) + .selectFirst .either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, @@ -425,13 +397,12 @@ object CqlSpec { ) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonAttribute] - .selectFirst(session) + .selectFirst .either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) }, @@ -443,13 +414,12 @@ object CqlSpec { ) for { - session <- ZIO.service[Session] _ <- cql"INSERT INTO tests.person_attributes (person_id, info) VALUES (${data.personId}, ${data.info})" - .execute(session) + .execute result <- cql"SELECT person_id, info FROM tests.person_attributes WHERE person_id = ${data.personId}" .as[PersonAttribute] - .selectFirst(session) + .selectFirst .either } yield assert(result)(isLeft(isSubtype[UnexpectedNullValue](Assertion.anything))) } diff --git a/src/main/scala/zio/cassandra/session/cql/query/Batch.scala b/src/main/scala/zio/cassandra/session/cql/query/Batch.scala index 6663004..372e699 100644 --- a/src/main/scala/zio/cassandra/session/cql/query/Batch.scala +++ b/src/main/scala/zio/cassandra/session/cql/query/Batch.scala @@ -1,13 +1,17 @@ package zio.cassandra.session.cql.query -import com.datastax.oss.driver.api.core.cql.{ BatchStatementBuilder, BatchType } -import zio.Task +import com.datastax.oss.driver.api.core.cql.{BatchStatementBuilder, BatchType} import zio.cassandra.session.Session +import zio.{Has, RIO, ZIO} class Batch(batchStatementBuilder: BatchStatementBuilder) { - def add(queries: Seq[Query[_]]) = new Batch(batchStatementBuilder.addStatements(queries.map(_.statement): _*)) - def execute(session: Session): Task[Boolean] = - session.execute(batchStatementBuilder.build()).map(_.wasApplied) + def add(queries: Seq[Query[_]]) = new Batch(batchStatementBuilder.addStatements(queries.map(_.statement): _*)) + + def execute: RIO[Has[Session], Boolean] = + ZIO.accessM[Has[Session]] { session => + session.get.execute(batchStatementBuilder.build()).map(_.wasApplied) + } + def config(config: BatchStatementBuilder => BatchStatementBuilder): Batch = new Batch(config(batchStatementBuilder)) } diff --git a/src/main/scala/zio/cassandra/session/cql/query/ParameterizedQuery.scala b/src/main/scala/zio/cassandra/session/cql/query/ParameterizedQuery.scala index 8c26b94..986f0d5 100644 --- a/src/main/scala/zio/cassandra/session/cql/query/ParameterizedQuery.scala +++ b/src/main/scala/zio/cassandra/session/cql/query/ParameterizedQuery.scala @@ -3,23 +3,27 @@ package zio.cassandra.session.cql.query import com.datastax.oss.driver.api.core.cql.BoundStatement import shapeless.HList import shapeless.ops.hlist.Prepend -import zio.Task -import zio.cassandra.session.cql.{ Binder, Reads } import zio.cassandra.session.Session -import zio.stream.Stream +import zio.cassandra.session.cql.{Binder, Reads} +import zio.stream.{Stream, ZStream} +import zio.{Has, RIO} case class ParameterizedQuery[V <: HList: Binder, R: Reads] private (template: QueryTemplate[V, R], values: V) { - def +(that: String): ParameterizedQuery[V, R] = ParameterizedQuery[V, R](this.template + that, this.values) - def as[R1: Reads]: ParameterizedQuery[V, R1] = ParameterizedQuery[V, R1](template.as[R1], values) - def select(session: Session): Stream[Throwable, R] = - Stream.unwrap(template.prepare(session).map(_.applyProduct(values).select)) - - def selectFirst(session: Session): Task[Option[R]] = - template.prepare(session).flatMap(_.applyProduct(values).selectFirst) - def execute(session: Session): Task[Boolean] = - template.prepare(session).map(_.applyProduct(values)).flatMap(_.execute) + def +(that: String): ParameterizedQuery[V, R] = ParameterizedQuery[V, R](this.template + that, this.values) + def as[R1: Reads]: ParameterizedQuery[V, R1] = ParameterizedQuery[V, R1](template.as[R1], values) + + def select: ZStream[Has[Session], Throwable, R] = + Stream.unwrap(template.prepare.map(_.applyProduct(values).select)) + + def selectFirst: RIO[Has[Session], Option[R]] = + template.prepare.flatMap(_.applyProduct(values).selectFirst) + + def execute: RIO[Has[Session], Boolean] = + template.prepare.map(_.applyProduct(values)).flatMap(_.execute) + def config(config: BoundStatement => BoundStatement): ParameterizedQuery[V, R] = ParameterizedQuery[V, R](template.config(config), values) + def stripMargin: ParameterizedQuery[V, R] = ParameterizedQuery[V, R](this.template.stripMargin, values) def ++[W <: HList, Out <: HList](that: ParameterizedQuery[W, R])(implicit diff --git a/src/main/scala/zio/cassandra/session/cql/query/Query.scala b/src/main/scala/zio/cassandra/session/cql/query/Query.scala index b5fa9c8..228bfea 100644 --- a/src/main/scala/zio/cassandra/session/cql/query/Query.scala +++ b/src/main/scala/zio/cassandra/session/cql/query/Query.scala @@ -11,13 +11,16 @@ class Query[R: Reads] private[cql] ( private[cql] val statement: BoundStatement ) { def config(statement: BoundStatement => BoundStatement) = new Query[R](session, statement(this.statement)) + def select: Stream[Throwable, R] = session.select(statement).mapChunksM { chunk => chunk.mapM(Reads[R].read(_, 0)) } + def selectFirst: Task[Option[R]] = session.selectFirst(statement).flatMap { case None => ZIO.none case Some(row) => Reads[R].read(row, 0).map(Some(_)) } + def execute: Task[Boolean] = session.execute(statement).map(_.wasApplied) } diff --git a/src/main/scala/zio/cassandra/session/cql/query/QueryTemplate.scala b/src/main/scala/zio/cassandra/session/cql/query/QueryTemplate.scala index 715cc77..a5b22ed 100644 --- a/src/main/scala/zio/cassandra/session/cql/query/QueryTemplate.scala +++ b/src/main/scala/zio/cassandra/session/cql/query/QueryTemplate.scala @@ -3,9 +3,9 @@ package zio.cassandra.session.cql.query import com.datastax.oss.driver.api.core.cql.BoundStatement import shapeless.HList import shapeless.ops.hlist.Prepend -import zio.Task -import zio.cassandra.session.cql.{ Binder, Reads } import zio.cassandra.session.Session +import zio.cassandra.session.cql.{Binder, Reads} +import zio.{Has, RIO, ZIO} import scala.annotation.nowarn @@ -13,12 +13,17 @@ case class QueryTemplate[V <: HList: Binder, R: Reads] private[cql] ( query: String, config: BoundStatement => BoundStatement ) { - def +(that: String): QueryTemplate[V, R] = QueryTemplate[V, R](this.query + that, config) - def as[R1: Reads]: QueryTemplate[V, R1] = QueryTemplate[V, R1](query, config) - def prepare(session: Session): Task[PreparedQuery[V, R]] = - session.prepare(query).map(new PreparedQuery(session, _, config)) + def +(that: String): QueryTemplate[V, R] = QueryTemplate[V, R](this.query + that, config) + def as[R1: Reads]: QueryTemplate[V, R1] = QueryTemplate[V, R1](query, config) + + def prepare: RIO[Has[Session], PreparedQuery[V, R]] = + ZIO.accessM[Has[Session]] { session => + session.get.prepare(query).map(new PreparedQuery(session.get, _, config)) + } + def config(config: BoundStatement => BoundStatement): QueryTemplate[V, R] = QueryTemplate[V, R](this.query, this.config andThen config) + def stripMargin: QueryTemplate[V, R] = QueryTemplate[V, R](this.query.stripMargin, this.config) def ++[W <: HList, Out <: HList](that: QueryTemplate[W, R])(implicit From 7b3ef205cc4b55290e43478bb2bb74f30b9513a1 Mon Sep 17 00:00:00 2001 From: Sergey Rublev Date: Thu, 17 Mar 2022 23:29:38 +0300 Subject: [PATCH 3/3] add Session.live based on ZLayer --- src/main/scala/zio/cassandra/session/Session.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/main/scala/zio/cassandra/session/Session.scala b/src/main/scala/zio/cassandra/session/Session.scala index 3b128e5..9a9ab60 100644 --- a/src/main/scala/zio/cassandra/session/Session.scala +++ b/src/main/scala/zio/cassandra/session/Session.scala @@ -93,6 +93,11 @@ object Session { override def keyspace: Option[CqlIdentifier] = underlying.getKeyspace.toScala } + def live: ZManaged[Has[CqlSessionBuilder], Throwable, Session] = + ZManaged.serviceWithManaged[CqlSessionBuilder] { cqlSessionBuilder => + make(cqlSessionBuilder) + } + def make(builder: => CqlSessionBuilder): TaskManaged[Session] = ZManaged .make(Task.fromCompletionStage(builder.buildAsync())) { session =>