spark-tensorflow-connector

This repo contains a library for loading and storing TensorFlow records with Apache Spark. The library implements data import from the standard TensorFlow record format (TFRecords) into Spark SQL DataFrames, and data export from DataFrames to TensorFlow records.

What's new

This is the initial release of the spark-tensorflow-connector repo.

Prerequisites

  1. Apache Spark 2.0 (or later)

  2. Apache Maven

  3. TensorFlow Hadoop - provided as a Maven dependency. You can also build the latest version yourself, as shown in the build steps below.

Building the library

Build the library using Maven 3.3.9 or newer as shown below:

# Build TensorFlow Hadoop
cd ../../hadoop
mvn clean install

# Build Spark TensorFlow connector
cd ../spark/spark-tensorflow-connector
mvn clean install

To build the library for a different version of TensorFlow, e.g., 1.5.0, use:

# Build TensorFlow Hadoop
cd ../../hadoop
mvn versions:set -DnewVersion=1.5.0
mvn clean install

# Build Spark TensorFlow connector
cd ../spark/spark-tensorflow-connector
mvn versions:set -DnewVersion=1.5.0
mvn clean install

Using Spark Shell

Include this library in Spark by passing the built JAR via the --jars command-line option to spark-shell or spark-submit. For example:

$SPARK_HOME/bin/spark-shell --jars target/spark-tensorflow-connector_2.11-1.6.0.jar

Features

This library allows reading TensorFlow records from a local or distributed filesystem into Spark DataFrames. When reading TensorFlow records into a Spark DataFrame, the API accepts several options:

  • load: input path to TensorFlow records. As with other Spark data sources, the path can include standard Hadoop globbing expressions.
  • schema: optional schema of the TensorFlow records, defined using a Spark StructType. If not provided, the schema is inferred from the TensorFlow records.
  • recordType: input format of TensorFlow records. By default it is Example. Possible values are:
    • Example
    • SequenceExample
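
For example, a read that combines these options might look like the following (a minimal sketch; the path and schema are placeholders):

import org.apache.spark.sql.types._

// Hypothetical schema and path, shown only to illustrate the options above.
val mySchema = StructType(List(StructField("label", LongType)))
val myDf = spark.read.format("tfrecords")
  .option("recordType", "Example")
  .schema(mySchema)
  .load("hdfs:///path/to/records/*.tfrecord")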

When writing a Spark DataFrame to TensorFlow records, the API accepts several options:

  • save: output path for the TensorFlow records, on a local or distributed filesystem.
  • recordType: output format of TensorFlow records. By default it is Example. Possible values are:
    • Example
    • SequenceExample
  • writeLocality: determines whether the TensorFlow records are written locally on the workers or on a distributed file system. Possible values are:
    • distributed (default): the DataFrame is written using Spark's default file system.
    • local: writes the content to the local disk of each Spark worker, in a partitioned manner (see details in the paragraph below).
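
A minimal write sketch combining these options (assuming df is an existing DataFrame; the output path is a placeholder):

df.write.format("tfrecords")
  .option("recordType", "Example")
  .option("writeLocality", "distributed")
  .save("hdfs:///path/to/output")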

In local mode, each worker stores a subset of the data on its local disk. The subset stored on each worker is determined by the partitioning of the DataFrame. Each partition is coalesced into a single TFRecord file and written on the node where the partition lives. This is useful in the context of distributed training, in which each worker gets a subset of the data to work on. When this mode is activated, the path provided to the writer is interpreted as a base path that is created on each worker node and populated with data from the DataFrame. For example, the following code:

myDataFrame.write.format("tfrecords").option("writeLocality", "local").save("/path")

will lead each worker node to have files such as the following:

  • worker1: /path/part-0001.tfrecord, /path/part-0002.tfrecord, ...
  • worker2: /path/part-0042.tfrecord, ...
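
Since each partition becomes a single TFRecord file, repartitioning the DataFrame before a local write controls how many files land on each worker. A hedged sketch (the partition count is arbitrary):

myDataFrame.repartition(8).write.format("tfrecords").option("writeLocality", "local").save("/path")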

Schema inference

This library supports automatic schema inference when reading TensorFlow records into Spark DataFrames. Schema inference is expensive since it requires an extra pass through the data.

The schema inference rules are described in the table below:

TFRecordType             | Feature Type              | Inferred Spark Data Type
Example, SequenceExample | Int64List                 | LongType if all lists have length=1, else ArrayType(LongType)
Example, SequenceExample | FloatList                 | FloatType if all lists have length=1, else ArrayType(FloatType)
Example, SequenceExample | BytesList                 | StringType if all lists have length=1, else ArrayType(StringType)
SequenceExample          | FeatureList of Int64List  | ArrayType(ArrayType(LongType))
SequenceExample          | FeatureList of FloatList  | ArrayType(ArrayType(FloatType))
SequenceExample          | FeatureList of BytesList  | ArrayType(ArrayType(StringType))
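
The length=1 rule can be observed by writing a small DataFrame and reading it back without an explicit schema (a minimal sketch; the column names and output path are placeholders):

import spark.implicits._

val demoDf = Seq((1L, Seq(1L, 2L)), (2L, Seq(3L))).toDF("scalar_col", "list_col")
demoDf.write.format("tfrecords").option("recordType", "Example").save("inference-demo.tfrecord")

// scalar_col is inferred as LongType (every Int64List has length 1),
// while list_col is inferred as ArrayType(LongType).
spark.read.format("tfrecords").option("recordType", "Example").load("inference-demo.tfrecord").printSchema()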

Supported data types

The supported Spark data types are listed in the table below:

Type            | Spark DataTypes
Scalar          | IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType, BinaryType
Array           | VectorType, ArrayType of IntegerType, LongType, FloatType, DoubleType, DecimalType, BinaryType, or StringType
Array of Arrays | ArrayType of ArrayType of IntegerType, LongType, FloatType, DoubleType, DecimalType, BinaryType, or StringType
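
To illustrate the Array row above, a Spark ML vector column can be written alongside scalar columns. This is a minimal sketch that assumes the VectorType entry refers to Spark ML vectors; the column names and output path are placeholders:

import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

// Two rows with an integer id and a dense vector feature column.
val vecDf = Seq((1, Vectors.dense(0.1, 0.2)), (2, Vectors.dense(0.3, 0.4))).toDF("id", "features")
vecDf.write.format("tfrecords").option("recordType", "Example").save("vector-demo.tfrecord")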

Usage Examples

The following code snippet demonstrates usage on test data.

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

val path = "test-output.tfrecord"
val testRows: Array[Row] = Array(
  new GenericRow(Array[Any](11, 1, 23L, 10.0F, 14.0, List(1.0, 2.0), "r1")),
  new GenericRow(Array[Any](21, 2, 24L, 12.0F, 15.0, List(2.0, 2.0), "r2")))
val schema = StructType(List(StructField("id", IntegerType), 
                             StructField("IntegerTypeLabel", IntegerType),
                             StructField("LongTypeLabel", LongType),
                             StructField("FloatTypeLabel", FloatType),
                             StructField("DoubleTypeLabel", DoubleType),
                             StructField("VectorLabel", ArrayType(DoubleType, true)),
                             StructField("name", StringType)))
                             
val rdd = spark.sparkContext.parallelize(testRows)

//Save DataFrame as TFRecords
val df: DataFrame = spark.createDataFrame(rdd, schema)
df.write.format("tfrecords").option("recordType", "Example").save(path)

//Read TFRecords into DataFrame.
//The DataFrame schema is inferred from the TFRecords if no custom schema is provided.
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").load(path)
importedDf1.show()

//Read TFRecords into DataFrame using custom schema
val importedDf2: DataFrame = spark.read.format("tfrecords").schema(schema).load(path)
importedDf2.show()

Loading YouTube-8M dataset to Spark

Here's how to import the YouTube-8M dataset into a Spark DataFrame.

curl http://us.data.yt8m.org/1/video_level/train/train-0.tfrecord > /tmp/video_level-train-0.tfrecord
curl http://us.data.yt8m.org/1/frame_level/train/train-0.tfrecord > /tmp/frame_level-train-0.tfrecord

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types._

//Import Video-level Example dataset into DataFrame
val videoSchema = StructType(List(StructField("video_id", StringType),
                             StructField("labels", ArrayType(IntegerType, true)),
                             StructField("mean_rgb", ArrayType(FloatType, true)),
                             StructField("mean_audio", ArrayType(FloatType, true))))
val videoDf: DataFrame = spark.read.format("tfrecords").schema(videoSchema).option("recordType", "Example").load("file:///tmp/video_level-train-0.tfrecord")
videoDf.show()
videoDf.write.format("tfrecords").option("recordType", "Example").save("youtube-8m-video.tfrecord")
val importedDf1: DataFrame = spark.read.format("tfrecords").option("recordType", "Example").schema(videoSchema).load("youtube-8m-video.tfrecord")
importedDf1.show()

//Import Frame-level SequenceExample dataset into DataFrame
val frameSchema = StructType(List(StructField("video_id", StringType),
                             StructField("labels", ArrayType(IntegerType, true)),
                             StructField("rgb", ArrayType(ArrayType(StringType, true),true)),
                             StructField("audio", ArrayType(ArrayType(StringType, true),true))))
val frameDf: DataFrame = spark.read.format("tfrecords").schema(frameSchema).option("recordType", "SequenceExample").load("file:///tmp/frame_level-train-0.tfrecord")
frameDf.show()
frameDf.write.format("tfrecords").option("recordType", "SequenceExample").save("youtube-8m-frame.tfrecord")
val importedDf2: DataFrame = spark.read.format("tfrecords").option("recordType", "SequenceExample").schema(frameSchema).load("youtube-8m-frame.tfrecord")
importedDf2.show()