diff --git a/README.md b/README.md
index 1423c85..5e47bfb 100644
--- a/README.md
+++ b/README.md
@@ -11,6 +11,7 @@ The artifacts are published to [bintray](https://bintray.com/linkedin/maven/spar
- Version 0.2.x targets Spark 2.4 and both Scala 2.11 and 2.12
- Version 0.3.x targets Spark 3.0 and Scala 2.12
- Version 0.4.x targets Spark 3.2 and Scala 2.12
+- Version 0.5.x targets Spark 3.2 and Scala 2.13
To use the package, please include the dependency as follows
diff --git a/pom.xml b/pom.xml
index 1c1af4c..dfa9cb9 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
com.linkedin.sparktfrecord
spark-tfrecord_${scala.binary.version}
jar
- 0.4.0
+ 0.5.0
spark-tfrecord
https://github.com/linkedin/spark-tfrecord
TensorFlow TFRecord data source for Apache Spark
@@ -354,7 +354,7 @@
scala-2.13
2.13
- 2.13.13
+ 2.13.8
diff --git a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala
index ccc8f3d..1e6c544 100644
--- a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala
+++ b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializer.scala
@@ -7,7 +7,7 @@ import org.apache.spark.sql.types.{DecimalType, DoubleType, _}
import org.apache.spark.unsafe.types.UTF8String
import org.tensorflow.example._
-import scala.collection.JavaConverters._
+import scala.jdk.CollectionConverters._
/**
* Creates a TFRecord deserializer to deserialize Tfrecord example or sequenceExample to Spark InternalRow
@@ -196,7 +196,7 @@ class TFRecordDeserializer(dataSchema: StructType) {
def bytesListFeature2SeqArrayByte(feature: Feature): Seq[Array[Byte]] = {
require(feature != null && feature.getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER, "Feature must be of type ByteList")
try {
- feature.getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte))
+ feature.getBytesList.getValueList.asScala.toSeq.map((byteArray) => byteArray.asScala.toArray.map(_.toByte))
}
catch {
case ex: Exception =>
diff --git a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala
index 40d7d98..8748f35 100644
--- a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala
+++ b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializer.scala
@@ -97,16 +97,16 @@ class TFRecordSerializer(dataSchema: StructType) {
val arrayData = getter.getArray(ordinal)
val featureOrFeatureList = elementType match {
case IntegerType =>
- Int64ListFeature(arrayData.toIntArray().map(_.toLong))
+ Int64ListFeature(arrayData.toIntArray().toSeq.map(_.toLong))
case LongType =>
- Int64ListFeature(arrayData.toLongArray())
+ Int64ListFeature(arrayData.toLongArray().toSeq)
case FloatType =>
- floatListFeature(arrayData.toFloatArray())
+ floatListFeature(arrayData.toFloatArray().toSeq)
case DoubleType =>
- floatListFeature(arrayData.toDoubleArray().map(_.toFloat))
+ floatListFeature(arrayData.toDoubleArray().toSeq.map(_.toFloat))
case DecimalType() =>
val elementConverter = arrayElementConverter(elementType)
@@ -117,7 +117,7 @@ class TFRecordSerializer(dataSchema: StructType) {
result(idx) = null
} else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Decimal]
}
- floatListFeature(result.map(_.toFloat))
+ floatListFeature(result.toSeq.map(_.toFloat))
case StringType | BinaryType =>
val elementConverter = arrayElementConverter(elementType)
@@ -128,13 +128,13 @@ class TFRecordSerializer(dataSchema: StructType) {
result(idx) = null
} else result(idx) = elementConverter(arrayData, idx).asInstanceOf[Array[Byte]]
}
- bytesListFeature(result)
+ bytesListFeature(result.toSeq)
// 2-dimensional array to TensorFlow "FeatureList"
case ArrayType(_, _) =>
val elementConverter = newFeatureConverter(elementType)
val featureList = FeatureList.newBuilder()
- for (idx <- 0 until arrayData.numElements) {
+ for (idx <- 0 until arrayData.numElements()) {
val feature = elementConverter(arrayData, idx).asInstanceOf[Feature]
featureList.addFeature(feature)
}
diff --git a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala
index 3dc5cc2..c640318 100644
--- a/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala
+++ b/src/main/scala/com/linkedin/spark/datasources/tfrecord/TensorFlowInferSchema.scala
@@ -18,8 +18,9 @@ package com.linkedin.spark.datasources.tfrecord
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.tensorflow.example.{FeatureList, SequenceExample, Example, Feature}
-import scala.collection.JavaConverters._
+
import scala.collection.mutable
+import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe._
object TensorFlowInferSchema {
diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala
index aafae43..7450be4 100644
--- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala
+++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordDeserializerTest.scala
@@ -273,7 +273,7 @@ class TFRecordDeserializerTest extends WordSpec with Matchers {
"Test bytesListFeature2SeqArrayByte" in {
val bytesList = BytesList.newBuilder().addValue(ByteString.copyFrom("str-input".getBytes)).build()
val bytesFeature = Feature.newBuilder().setBytesList(bytesList).build()
- assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head === "str-input".getBytes.deep)
+ assert(deserializer.bytesListFeature2SeqArrayByte(bytesFeature).head.sameElements("str-input".getBytes))
// Throw exception if type doesn't match
intercept[RuntimeException] {
diff --git a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala
index d1671a5..484e2cb 100644
--- a/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala
+++ b/src/test/scala/com/linkedin/spark/datasources/tfrecord/TFRecordSerializerTest.scala
@@ -121,11 +121,11 @@ class TFRecordSerializerTest extends WordSpec with Matchers {
assert(featureMap("StrArrayLabel").getBytesList.getValueList.asScala.map(_.toStringUtf8) === Seq("r2", "r3"))
assert(featureMap("BinaryLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
- assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.deep == byteArray.deep)
+ assert(featureMap("BinaryLabel").getBytesList.getValue(0).toByteArray.sameElements(byteArray))
assert(featureMap("BinaryArrayLabel").getKindCase.getNumber == Feature.BYTES_LIST_FIELD_NUMBER)
val binaryArrayValue = featureMap("BinaryArrayLabel").getBytesList.getValueList.asScala.map((byteArray) => byteArray.asScala.toArray.map(_.toByte))
- assert(binaryArrayValue.toArray.deep == Array(byteArray, byteArray1).deep)
+ binaryArrayValue.toArray should equal(Array(byteArray, byteArray1))
}
"Serialize internalRow to tfrecord sequenceExample" in {
@@ -199,12 +199,12 @@ class TFRecordSerializerTest extends WordSpec with Matchers {
assert(featureListMap("LongArrayOfArrayLabel").getFeatureList.asScala.map(
_.getInt64List.getValueList.asScala.toSeq) === longListOfLists)
- assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.map(
+ assert(featureListMap("FloatArrayOfArrayLabel").getFeatureList.asScala.toSeq.map(
_.getFloatList.getValueList.asScala.map(_.toFloat).toSeq) ~== floatListOfLists.map{arr => arr.toSeq}.toSeq)
- assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.map(
+ assert(featureListMap("DoubleArrayOfArrayLabel").getFeatureList.asScala.toSeq.map(
_.getFloatList.getValueList.asScala.map(_.toDouble).toSeq) ~== doubleListOfLists.map{arr => arr.toSeq}.toSeq)
- assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.map(
+ assert(featureListMap("DecimalArrayOfArrayLabel").getFeatureList.asScala.toSeq.map(
_.getFloatList.getValueList.asScala.map(x => Decimal(x.toDouble)).toSeq) ~== decimalListOfLists.map{arr => arr.toSeq}.toSeq)
assert(featureListMap("StringArrayOfArrayLabel").getFeatureList.asScala.map(
@@ -313,10 +313,10 @@ class TFRecordSerializerTest extends WordSpec with Matchers {
Array(0xff.toByte, 0xd8.toByte),
Array(0xff.toByte, 0xd9.toByte)))
- assert(bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) ===
- Seq(Array(0xff.toByte, 0xd8.toByte).deep))
- assert(bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray.deep) ===
- Seq(Array(0xff.toByte, 0xd8.toByte).deep, Array(0xff.toByte, 0xd9.toByte).deep))
+ bytesFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal(
+ Array(Array(0xff.toByte, 0xd8.toByte)))
+ bytesListFeature.getBytesList.getValueList.asScala.map(_.toByteArray).toArray should equal(
+ Array(Array(0xff.toByte, 0xd8.toByte), Array(0xff.toByte, 0xd9.toByte)))
}
}
}