Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PPL geoip function #871

Merged
merged 10 commits into from
Dec 19, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
}

protected def createGeoIpTestTable(testTable: String): Unit = {
sql(
s"""
sql(s"""
| CREATE TABLE $testTable
| (
| ip STRING,
Expand All @@ -782,8 +781,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| USING $tableType $tableOptions
|""".stripMargin)

sql(
s"""
sql(s"""
| INSERT INTO $testTable
| VALUES ('66.249.157.90', true),
| ('2a09:bac2:19f8:2ac3::', true),
Expand All @@ -793,8 +791,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
}

protected def createGeoIpTable(): Unit = {
sql(
s"""
sql(s"""
| CREATE TABLE geoip
| (
| cidr STRING,
Expand All @@ -813,8 +810,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| USING $tableType $tableOptions
|""".stripMargin)

sql(
s"""
sql(s"""
| INSERT INTO geoip
| VALUES (
| '66.249.157.0/24',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLGeoipITSuite
extends QueryTest
extends QueryTest
with LogicalPlanTestUtils
with FlintPPLSuite
with StreamTest {
Expand All @@ -34,8 +34,7 @@ class FlintSparkPPLGeoipITSuite
}

test("test geoip with no parameters") {
kenrickyap marked this conversation as resolved.
Show resolved Hide resolved
val frame = sql(
s"""
val frame = sql(s"""
| source = $testTable| where isValid = true | eval a = geoip(ip) | fields ip, a
| """.stripMargin)

Expand All @@ -44,37 +43,52 @@ class FlintSparkPPLGeoipITSuite

// Define the expected results
val expectedResults: Array[Row] = Array(
Row("66.249.157.90", Row("JM", "Jamaica", "North America", "14", "Saint Catherine Parish", "Portmore", "America/Jamaica", "17.9686,-76.8827")),
Row("2a09:bac2:19f8:2ac3::", Row("CA", "Canada", "North America", "PE", "Prince Edward Island", "Charlottetown", "America/Halifax", "46.2396,-63.1355"))
)
Row(
"66.249.157.90",
Row(
"JM",
"Jamaica",
"North America",
"14",
"Saint Catherine Parish",
"Portmore",
"America/Jamaica",
"17.9686,-76.8827")),
Row(
"2a09:bac2:19f8:2ac3::",
Row(
"CA",
"Canada",
"North America",
"PE",
"Prince Edward Island",
"Charlottetown",
"America/Halifax",
"46.2396,-63.1355")))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
}

test("test geoip with one parameters") {
val frame = sql(
s"""
val frame = sql(s"""
| source = $testTable| where isValid = true | eval a = geoip(ip, country_name) | fields ip, a
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("66.249.157.90", "Jamaica"),
Row("2a09:bac2:19f8:2ac3::", "Canada")
)
val expectedResults: Array[Row] =
Array(Row("66.249.157.90", "Jamaica"), Row("2a09:bac2:19f8:2ac3::", "Canada"))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
}

test("test geoip with multiple parameters") {
val frame = sql(
s"""
val frame = sql(s"""
| source = $testTable| where isValid = true | eval a = geoip(ip, country_name, city_name) | fields ip, a
| """.stripMargin)

Expand All @@ -83,8 +97,7 @@ class FlintSparkPPLGeoipITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("66.249.157.90", Row("Jamaica", "Portmore")),
Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown"))
)
Row("2a09:bac2:19f8:2ac3::", Row("Canada", "Charlottetown")))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
Expand Down
kenrickyap marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Join,
import org.apache.spark.sql.types.DataTypes

class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
extends SparkFunSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {
Expand All @@ -28,20 +28,18 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
private val pplParser = new PPLSyntaxParser()

private def getGeoIpQueryPlan(
ipAddress: UnresolvedAttribute,
left : LogicalPlan,
right : LogicalPlan,
projectionProperties : Alias
) : LogicalPlan = {
ipAddress: UnresolvedAttribute,
left: LogicalPlan,
right: LogicalPlan,
projectionProperties: Alias): LogicalPlan = {
val joinPlan = getJoinPlan(ipAddress, left, right)
getProjection(joinPlan, projectionProperties)
}

private def getJoinPlan(
ipAddress: UnresolvedAttribute,
left : LogicalPlan,
right : LogicalPlan
) : LogicalPlan = {
ipAddress: UnresolvedAttribute,
left: LogicalPlan,
right: LogicalPlan): LogicalPlan = {
val is_ipv4 = ScalaUDF(
kenrickyap marked this conversation as resolved.
Show resolved Hide resolved
SerializableUdf.geoIpUtils.isIpv4,
DataTypes.BooleanType,
Expand All @@ -50,8 +48,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
Option.empty,
Option.apply("is_ipv4"),
false,
true
)
true)
val ip_to_int = ScalaUDF(
SerializableUdf.geoIpUtils.ipToInt,
DataTypes.createDecimalType(38, 0),
Expand All @@ -60,29 +57,34 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
Option.empty,
Option.apply("ip_to_int"),
false,
true
)
true)

val t1 = SubqueryAlias("t1", left)
val t2 = SubqueryAlias("t2", right)

val joinCondition = And(
And(
GreaterThanOrEqual(ip_to_int, UnresolvedAttribute("t2.ip_range_start")),
LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))
),
EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4"))
)
LessThan(ip_to_int, UnresolvedAttribute("t2.ip_range_end"))),
EqualTo(is_ipv4, UnresolvedAttribute("t2.ipv4")))
Join(t1, t2, LeftOuter, Some(joinCondition), JoinHint.NONE)
}

private def getProjection(joinPlan : LogicalPlan, projectionProperties : Alias) : LogicalPlan = {
private def getProjection(joinPlan: LogicalPlan, projectionProperties: Alias): LogicalPlan = {
val projection = Project(Seq(UnresolvedStar(None), projectionProperties), joinPlan)
val dropList = Seq(
"t2.country_iso_code", "t2.country_name", "t2.continent_name",
"t2.region_iso_code", "t2.region_name", "t2.city_name",
"t2.time_zone", "t2.location", "t2.cidr", "t2.ip_range_start", "t2.ip_range_end", "t2.ipv4"
).map(UnresolvedAttribute(_))
"t2.country_iso_code",
"t2.country_name",
"t2.continent_name",
"t2.region_iso_code",
"t2.region_name",
"t2.city_name",
"t2.time_zone",
"t2.location",
"t2.cidr",
"t2.ip_range_start",
"t2.ip_range_end",
"t2.ipv4").map(UnresolvedAttribute(_))
DataFrameDropColumns(dropList, projection)
}

Expand All @@ -98,16 +100,24 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
val sourceTable = UnresolvedRelation(seq("users"))
val geoTable = UnresolvedRelation(seq("geoip"))

val projectionStruct = CreateNamedStruct(Seq(
Literal("country_iso_code"), UnresolvedAttribute("t2.country_iso_code"),
Literal("country_name"), UnresolvedAttribute("t2.country_name"),
Literal("continent_name"), UnresolvedAttribute("t2.continent_name"),
Literal("region_iso_code"), UnresolvedAttribute("t2.region_iso_code"),
Literal("region_name"), UnresolvedAttribute("t2.region_name"),
Literal("city_name"), UnresolvedAttribute("t2.city_name"),
Literal("time_zone"), UnresolvedAttribute("t2.time_zone"),
Literal("location"), UnresolvedAttribute("t2.location")
))
val projectionStruct = CreateNamedStruct(
Seq(
Literal("country_iso_code"),
UnresolvedAttribute("t2.country_iso_code"),
Literal("country_name"),
UnresolvedAttribute("t2.country_name"),
Literal("continent_name"),
UnresolvedAttribute("t2.continent_name"),
Literal("region_iso_code"),
UnresolvedAttribute("t2.region_iso_code"),
Literal("region_name"),
UnresolvedAttribute("t2.region_name"),
Literal("city_name"),
UnresolvedAttribute("t2.city_name"),
Literal("time_zone"),
UnresolvedAttribute("t2.time_zone"),
Literal("location"),
UnresolvedAttribute("t2.location")))
val structProjection = Alias(projectionStruct, "a")()

val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection)
Expand Down Expand Up @@ -135,7 +145,6 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}


test("test geoip function - ipAddress col exist in geoip table") {
val context = new CatalystPlanContext

Expand All @@ -158,7 +167,7 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
test("test geoip function - duplicate parameters") {
val context = new CatalystPlanContext

val exception = intercept[IllegalStateException]{
val exception = intercept[IllegalStateException] {
planTransformer.visit(
plan(pplParser, "source=t1 | eval a = geoip(cidr, country_name, country_name)"),
context)
Expand Down Expand Up @@ -197,10 +206,12 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite
val ipAddress = UnresolvedAttribute("ip_address")
val sourceTable = UnresolvedRelation(seq("users"))
val geoTable = UnresolvedRelation(seq("geoip"))
val projectionStruct = CreateNamedStruct(Seq(
Literal("country_name"), UnresolvedAttribute("t2.country_name"),
Literal("location"), UnresolvedAttribute("t2.location")
))
val projectionStruct = CreateNamedStruct(
Seq(
Literal("country_name"),
UnresolvedAttribute("t2.country_name"),
Literal("location"),
UnresolvedAttribute("t2.location")))
val structProjection = Alias(projectionStruct, "a")()

val geoIpPlan = getGeoIpQueryPlan(ipAddress, sourceTable, geoTable, structProjection)
Expand All @@ -214,7 +225,9 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite

val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"),
plan(
pplParser,
"source=t | eval a = geoip(ip_address, country_iso_code), b = geoip(ip_address, region_iso_code)"),
context)

val ipAddress = UnresolvedAttribute("ip_address")
Expand All @@ -237,7 +250,9 @@ class PPLLogicalPlanGeoipFunctionTranslatorTestSuite

val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"),
plan(
pplParser,
"source=t | eval a = geoip(ip_address, time_zone), b = rand(), c = geoip(ip_address, region_name)"),
context)

val ipAddress = UnresolvedAttribute("ip_address")
Expand Down
Loading