Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhao-db committed Jan 6, 2025
1 parent d44fd2b commit 83f37bb
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,33 +184,37 @@ case class ToVariantObject(child: Expression)
}
}

object VariantPathParser extends RegexParsers {
// A path segment in the `VariantGet` expression represents either an object key access or an
// array index access.
type PathSegment = Either[String, Int]
// A path segment in the `VariantGet` expression represents either an object key access or an array
// index access.
sealed abstract class VariantPathSegment extends Serializable

case class ObjectExtraction(key: String) extends VariantPathSegment

case class ArrayExtraction(index: Int) extends VariantPathSegment

object VariantPathParser extends RegexParsers {
private def root: Parser[Char] = '$'

// Parse index segment like `[123]`.
private def index: Parser[PathSegment] =
private def index: Parser[VariantPathSegment] =
for {
index <- '[' ~> "\\d+".r <~ ']'
} yield {
scala.util.Right(index.toInt)
ArrayExtraction(index.toInt)
}

// Parse key segment like `.name`, `['name']`, or `["name"]`.
private def key: Parser[PathSegment] =
private def key: Parser[VariantPathSegment] =
for {
key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
"[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
} yield {
scala.util.Left(key)
ObjectExtraction(key)
}

private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key | index))
private val parser: Parser[List[VariantPathSegment]] = phrase(root ~> rep(key | index))

def parse(str: String): Option[Array[PathSegment]] = {
def parse(str: String): Option[Array[VariantPathSegment]] = {
this.parseAll(parser, str) match {
case Success(result, _) => Some(result.toArray)
case _ => None
Expand Down Expand Up @@ -349,14 +353,14 @@ case object VariantGet {
/** The actual implementation of the `VariantGet` expression. */
def variantGet(
input: VariantVal,
parsedPath: Array[VariantPathParser.PathSegment],
parsedPath: Array[VariantPathSegment],
dataType: DataType,
castArgs: VariantCastArgs): Any = {
var v = new Variant(input.getValue, input.getMetadata)
for (path <- parsedPath) {
v = path match {
case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
case ObjectExtraction(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
case ArrayExtraction(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
case _ => null
}
if (v == null) return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable.HashMap

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.variant.{VariantGet, VariantPathParser}
import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -54,7 +54,7 @@ case class VariantMetadata(
.build()
).build()

def parsedPath(): Array[VariantPathParser.PathSegment] = {
def parsedPath(): Array[VariantPathSegment] = {
VariantPathParser.parse(path).getOrElse {
val name = if (failOnError) "variant_get" else "try_variant_get"
throw QueryExecutionErrors.invalidVariantGetPath(path, name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.expressions.variant.VariantPathParser.PathSegment
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.RowToColumnConverter
Expand Down Expand Up @@ -56,9 +55,9 @@ case class SparkShreddedRow(row: SpecializedGetters) extends ShreddingUtils.Shre
override def numElements(): Int = row.asInstanceOf[ArrayData].numElements()
}

// The search result of a `PathSegment` in a `VariantSchema`.
// The search result of a `VariantPathSegment` in a `VariantSchema`.
case class SchemaPathSegment(
rawPath: PathSegment,
rawPath: VariantPathSegment,
// Whether this path segment is an object or array extraction.
isObject: Boolean,
// `schema.typedIdx`, if the path exists in the schema (for object extraction, the schema
Expand Down Expand Up @@ -714,11 +713,11 @@ case object SparkShreddingUtils {
// found at a certain level of the file type, then `typedIdx` will be -1 starting from
// this position, and the final `schema` will be null.
for (i <- rawPath.indices) {
val isObject = rawPath(i).isLeft
val isObject = rawPath(i).isInstanceOf[ObjectExtraction]
var typedIdx = -1
var extractionIdx = -1
rawPath(i) match {
case scala.util.Left(key) if schema != null && schema.objectSchema != null =>
case ObjectExtraction(key) if schema != null && schema.objectSchema != null =>
val fieldIdx = schema.objectSchemaMap.get(key)
if (fieldIdx != null) {
typedIdx = schema.typedIdx
Expand All @@ -727,7 +726,7 @@ case object SparkShreddingUtils {
} else {
schema = null
}
case scala.util.Right(index) if schema != null && schema.arraySchema != null =>
case ArrayExtraction(index) if schema != null && schema.arraySchema != null =>
typedIdx = schema.typedIdx
extractionIdx = index
schema = schema.arraySchema
Expand Down Expand Up @@ -770,8 +769,8 @@ case object SparkShreddingUtils {
var v = new Variant(row.getBinary(variantIdx), topLevelMetadata)
while (pathIdx < pathLen) {
v = pathList(pathIdx).rawPath match {
case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
case ObjectExtraction(key) if v.getType == Type.OBJECT => v.getFieldByKey(key)
case ArrayExtraction(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index)
case _ => null
}
if (v == null) return null
Expand Down

0 comments on commit 83f37bb

Please sign in to comment.