diff --git a/README.md b/README.md index 1423c85..5e47bfb 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ The artifacts are published to [bintray](https://bintray.com/linkedin/maven/spar - Version 0.2.x targets Spark 2.4 and both Scala 2.11 and 2.12 - Version 0.3.x targets Spark 3.0 and Scala 2.12 - Version 0.4.x targets Spark 3.2 and Scala 2.12 +- Version 0.5.x targets Spark 3.2 and Scala 2.13 To use the package, please include the dependency as follows diff --git a/pom.xml b/pom.xml index 1c1af4c..dfa9cb9 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.linkedin.sparktfrecord spark-tfrecord_${scala.binary.version} jar - 0.4.0 + 0.5.0 spark-tfrecord https://github.com/linkedin/spark-tfrecord TensorFlow TFRecord data source for Apache Spark @@ -354,7 +354,7 @@ scala-2.13 2.13 - 2.13.13 + 2.13.8 diff --git a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala index ccc8f3d..1e6c544 100644 --- a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala +++ b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala @@ -7,7 +7,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType, _} import org.apache.spark.unsafe.types.UTF8String import org.tensorflow.example._ -import scala.collection.JavaConverters._ +import scala.jdk.CollectionConverters._ /** * Creates a TFRecord deserializer to deserialize Tfrecord example or sequenceExample to Spark InternalRow @@ -196,7 +196,7 @@ class TFRecordDeserializer(dataSchema: StructType) { def bytesListFeature2SeqArrayByte(feature: Feature): Seq[Array[Byte]] = { require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList") try { - feature.getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) + feature.getBytesList.getValueList.asScala.toSeq.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) } catch { case ex: Exception => diff --git a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala index 40d7d98..8748f35 100644 --- a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala +++ b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala @@ -97,16 +97,16 @@ class TFRecordSerializer(dataSchema: StructType) { val arrayData = getter.getArray(ordinal) val featureOrFeatureList = elementType match { case IntegerType => - Int64ListFeature(arrayData.toIntArray().map(_.toLong)) + Int64ListFeature(arrayData.toIntArray().toSeq.map(_.toLong)) case LongType => - Int64ListFeature(arrayData.toLongArray()) + Int64ListFeature(arrayData.toLongArray().toSeq) case FloatType => - floatListFeature(arrayData.toFloatArray()) + floatListFeature(arrayData.toFloatArray().toSeq) case DoubleType => - floatListFeature(arrayData.toDoubleArray().map(_.toFloat)) + floatListFeature(arrayData.toDoubleArray().toSeq.map(_.toFloat)) case DecimalType() => val elementConverter = arrayElementConverter(elementType) @@ -117,7 +117,7 @@ class TFRecordSerializer(dataSchema: StructType) { result(idx) = null } else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Decimal] } - floatListFeature(result.map(_.toFloat)) + floatListFeature(result.toSeq.map(_.toFloat)) case StringType | BinaryType => val elementConverter = arrayElementConverter(elementType) @@ -128,13 +128,13 @@ class TFRecordSerializer(dataSchema: StructType) { result(idx) = null } else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Array[Byte]] } - bytesListFeature(result) + bytesListFeature(result.toSeq) // 2-dimensional array to TensorFlow "FeatureList" case ArrayType(_, _) => val elementConverter = newFeatureConverter(elementType) val featureList = FeatureList.newBuilder() - for (idx <- 0 until arrayData.numElements) { + for (idx <- 0 until arrayData.numElements()) { val feature = elementConverter(arrayData, idx).asInstanceOf[Feature] featureList.addFeature(feature) } diff --git a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala index 3dc5cc2..c640318 100644 --- a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala +++ b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala @@ -18,8 +18,9 @@ package com.linkedin.spark.datasources.tfrecord import org.apache.spark.rdd.RDD import org.apache.spark.sql.types._ import org.tensorflow.example.{FeatureList, SequenceExample, Example, Feature} -import scala.collection.JavaConverters._ + import scala.collection.mutable +import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe._ object TensorFlowInferSchema { diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala index aafae43..7450be4 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala @@ -273,7 +273,7 @@ class TFRecordDeserializerTest extends WordSpec with Matchers { "Test bytesListFeature2SeqArrayByte" in { val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build() val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build() - assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head === "str-input".getBytes.deep) + assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head.sameElements("str-input".getBytes)) // Throw exception if type doesn't match intercept[RuntimeException] { diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala index d1671a5..484e2cb 100644 --- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala +++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala @@ -121,11 +121,11 @@ class TFRecordSerializerTest extends WordSpec with Matchers { assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3")) assert(featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) - assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.deep == byteArray.deep) + assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.sameElements(byteArray)) assert(featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER) val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte)) - assert(binaryArrayValue.toArray.deep == Array(byteArray, byteArray1).deep) + binaryArrayValue.toArray should equal(Array(byteArray, byteArray1)) } "Serialize internalRow to tfrecord sequenceExample" in { @@ -199,12 +199,12 @@ class TFRecordSerializerTest extends WordSpec with Matchers { assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map( _.getInt64List.getValueList.asScala.toSeq) === longListOfLists) - assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map( + assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.toSeq.map( _.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists.map{arr => arr.toSeq}.toSeq) - assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map( + assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.toSeq.map( _.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists.map{arr => arr.toSeq}.toSeq) - assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.map( + assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.toSeq.map( _.getFloatList.getValueList.asScala.map(x => Decimal(x.toDouble)).toSeq) ~== decimalListOfLists.map{arr => arr.toSeq}.toSeq) assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map( @@ -313,10 +313,10 @@ class TFRecordSerializerTest extends WordSpec with Matchers { Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte))) - assert(bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === - Seq(Array(0xff.toByte, 0xd8.toByte).deep)) - assert(bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) === - Seq(Array(0xff.toByte, 0xd8.toByte).deep, Array(0xff.toByte, 0xd9.toByte).deep)) + bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal( + Array(Array(0xff.toByte, 0xd8.toByte))) + bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal( + Array(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte))) } } }