Commit 28c52e9

Always resolve current field name in SparkSQL when creating Row objects inside of arrays (#2158) (#2166)
* Resolve current field every time when creating Row objects in arrays.
* Expand tests to ensure no breakages.
* Mirror changes to the sql-20 source root.
* Throw an exception for the case where there is no current field when creating a map for an array.

(cherry picked from commit 668fcc1)
1 parent 3d1b299 commit 28c52e9
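
The failure mode reads best end to end. Below is a minimal sketch of the document shape the fix targets, assuming a SparkContext `sc` and SQLContext `sqc` wired to a test Elasticsearch cluster, and using the same `saveJsonToEs` / `format("es")` calls as the new tests in this commit; the index name is illustrative:

    import org.apache.spark.sql.Row
    import org.elasticsearch.spark._        // saveJsonToEs on RDDs
    import org.elasticsearch.spark.sql._    // the "es" data source

    // An array whose first element ends with an empty object ({}) right
    // before the next element begins -- the shape that confused row ordering.
    val data = """{"name":"repro","nested-field":[{"key":"value1","subnested-field":{}},{"key":"value2"}]}"""

    sc.makeRDD(Seq(data)).saveJsonToEs("sparksql-repro/data")

    val df = sqc.read.format("es").load("sparksql-repro")
    // Before this commit, materializing the second array element could bind
    // values to the wrong row order; with the fix both elements read cleanly.
    df.head().getSeq[Row](1).foreach(println)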

File tree

4 files changed: +278 −33 lines

spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 123 additions & 17 deletions
@@ -35,9 +35,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.SparkConf
 import org.apache.spark.SparkContext
 import org.apache.spark.SparkException
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.{Row, SQLContext, SaveMode, SparkSession}
 import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.sql.types.DecimalType
@@ -89,7 +87,6 @@ import com.esotericsoftware.kryo.io.{Output => KryoOutput}
 import org.apache.spark.rdd.RDD
 
 import javax.xml.bind.DatatypeConverter
-import org.apache.spark.sql.SparkSession
 import org.elasticsearch.hadoop.EsAssume
 import org.elasticsearch.hadoop.TestData
 import org.elasticsearch.hadoop.cfg.ConfigurationOptions
@@ -293,6 +290,15 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
 
     df.take(1).foreach(println)
     assertEquals(1, df.count())
+    val head = df.head()
+    val arr = head.getSeq[Row](0);
+    assertThat(arr.size, is(2))
+    assertEquals(arr(0).getString(0), "1")
+    assertEquals(arr(0).getString(1), "2")
+    assertEquals(arr(1).getString(0), "unu")
+    assertEquals(arr(1).getString(1), "doi")
+    val topLevel = head.getString(1)
+    assertEquals(topLevel, "root")
   }
 
   @Test
@@ -347,9 +353,30 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val mapping = SchemaUtilsTestable.rowInfo(cfgSettings)
 
     val df = sqc.read.options(newCfg).format("org.elasticsearch.spark.sql").load(target)
-    df.printSchema()
-    df.take(1).foreach(println)
+    // df.printSchema()
+    // df.take(1).foreach(println)
     assertEquals(1, df.count())
+
+    val document = df.take(1).head
+    assertEquals(text, document.getString(1)) // .foo
+    assertEquals(0L, document.getLong(2)) // .level
+    assertEquals(text, document.getString(3)) // .level1
+
+    val bar = document.getStruct(0) // .bar
+    assertEquals(10L, bar.getLong(1)) // .bar.foo2
+    assertEquals(1L, bar.getLong(2)) // .bar.level
+    assertEquals(2L, bar.getLong(3)) // .bar.level2
+
+    val barbar = bar.getStruct(0) // .bar.bar
+    assertEquals(2L, barbar.getLong(1)) // .bar.bar.level
+    assertTrue(barbar.getBoolean(2)) // .bar.bar.level3
+
+    val barbarbar = barbar.getSeq[Row](0) // .bar.bar.bar
+    assertEquals(2, barbarbar.size)
+    val barbarbar0bar = barbarbar.head // .bar.bar.bar.[0]
+    assertEquals(1L, barbarbar0bar.getLong(0)) // .bar.bar.bar.[0].bar
+    val barbarbar1bar = barbarbar.last // .bar.bar.bar.[1]
+    assertEquals(2L, barbarbar1bar.getLong(0)) // .bar.bar.bar.[1].bar
   }
 
   @Test
@@ -369,6 +396,19 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     df.printSchema()
     df.take(1).foreach(println)
     assertEquals(1, df.count())
+
+    val document = df.head()
+    assertEquals(5L, document.getLong(0)) // .foo
+    val nested = document.getStruct(1) // .nested
+    val bar = nested.getSeq[Row](0) // .nested.bar
+    assertEquals(2, bar.size)
+    val bar1 = bar.head // .nested.bar.[1]
+    assertEquals(20L, bar1.getLong(0)) // .nested.bar.[1].age
+    assertEquals(new Timestamp(115, 0, 1, 0, 0, 0, 0), bar1.getTimestamp(1)) // .nested.bar.[1].date
+    val bar2 = bar.last // .nested.bar.[2]
+    assertEquals(20L, bar2.getLong(0)) // .nested.bar.[2].age
+    assertEquals(new Timestamp(115, 0, 1, 0, 0, 0, 0), bar2.getTimestamp(1)) // .nested.bar.[2].date
+    assertEquals("now", nested.getString(1)) // .nested.what
   }
 
   @Test
@@ -1563,17 +1603,20 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertEquals("array", bar.dataType.typeName)
     val scores = bar.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]("scores")
     assertEquals("array", scores.dataType.typeName)
-
-    val head = df.head
-    val foo = head.getSeq[Long](0)
-    assertEquals(5, foo(0))
-    assertEquals(6, foo(1))
-    // nested
-    val nested = head.getStruct(1)
-    assertEquals("now", nested.getString(1))
-    val nestedDate = nested.getSeq[Row](0)
-    val nestedScores = nestedDate(0).getSeq[Long](1)
-    assertEquals(2l, nestedScores(1))
+
+    val document = df.head
+    val foo = document.getSeq[Long](0) // .foo
+    assertEquals(5, foo(0)) // .foo[0]
+    assertEquals(6, foo(1)) // .foo[1]
+    val nested = document.getStruct(1) // .nested
+    assertEquals("now", nested.getString(1)) // .nested.what
+
+    val nestedBar = nested.getSeq[Row](0) // .nested.bar
+    val nestedScores = nestedBar(0).getSeq[Long](1) // .nested.bar.[0].scores
+    assertEquals(2l, nestedScores(1)) // .nested.bar.[0].scores.[1]
+
+    val nestedScores2 = nestedBar(1).getSeq[Long](1) // .nested.bar.[1].scores
+    assertEquals(4l, nestedScores2(1)) // .nested.bar.[1].scores.[1]
   }
 
   //@Test
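
For orientation, the assertions in the hunk above pin down a source document of roughly this shape; it is a hypothetical reconstruction, since only foo, .nested.what, and scores at index 1 are asserted — the scores at index 0 (here 1 and 3) and the unasserted field at struct position 0 of each bar element are guesses/omitted:

    val doc = """{
      "foo": [5, 6],
      "nested": {
        "bar": [
          {"scores": [1, 2]},
          {"scores": [3, 4]}
        ],
        "what": "now"
      }
    }"""
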
@@ -2282,6 +2325,69 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertThat(nested.size, is(3))
     assertEquals(nested(0).getString(0), "anne")
     assertEquals(nested(0).getLong(1), 6)
+    assertEquals(nested(1).getString(0), "bob")
+    assertEquals(nested(1).getLong(1), 100)
+    assertEquals(nested(2).getString(0), "charlie")
+    assertEquals(nested(2).getLong(1), 15)
+  }
+
+  @Test
+  def testNestedWithEmptyObjectAtTail() {
+    val mapping = wrapMapping("data",
+      s"""{
+        |  "properties": {
+        |    "name": { "type": "$keyword" },
+        |    "nested-field": {
+        |      "type": "nested",
+        |      "properties": {
+        |        "key": {"type": "$keyword"},
+        |        "subnested-field": {
+        |          "type": "nested",
+        |          "properties": {
+        |            "subkey": {"type": "$keyword"}
+        |          }
+        |        }
+        |      }
+        |    }
+        |  }
+        |}
+      """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-nested-empty-object-at-tail")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val data = """{"name":"nested-empty-object","nested-field":[{"key": "value1","subnested-field":{}},{"key": "value2"}]}""".stripMargin
+
+    sc.makeRDD(Seq(data)).saveJsonToEs(target)
+    val df = sqc.read.format("es").load(index)
+
+    println(df.schema.treeString)
+
+    val dataType = df.schema("nested-field").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals("struct", array.elementType.typeName)
+    val struct = array.elementType.asInstanceOf[StructType]
+    assertEquals("string", struct("key").dataType.typeName)
+    assertEquals("array", struct("subnested-field").dataType.typeName)
+
+    val subArrayType = struct("subnested-field").dataType
+    assertEquals("array", subArrayType.typeName)
+    val subArray = subArrayType.asInstanceOf[ArrayType]
+    assertEquals("struct", subArray.elementType.typeName)
+    val subStruct = subArray.elementType.asInstanceOf[StructType]
+    assertEquals("string", subStruct("subkey").dataType.typeName)
+
+    val head = df.head()
+    val nested = head.getSeq[Row](1) // .nested-field
+    assertThat(nested.size, is(2))
+    assertEquals(nested(0).getString(0), "value1") // .nested-field.[0].key matches
+    assertEquals(nested(0).getSeq(1).size, 1) // .nested-field.[0].subnested-field is singleton list
+    assertNull(nested(0).getSeq[Row](1).head.get(0)) // .nested-field.[0].subnested-field.[0] is empty object
+    assertEquals(nested(1).getString(0), "value2") // .nested-field.[1].key matches
   }
 
 
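Read together, the schema assertions in testNestedWithEmptyObjectAtTail say the trailing empty object still yields a full nested-array mapping, so the printed df.schema.treeString should look roughly like this (reconstructed from the assertions, not captured output):

    root
     |-- name: string (nullable = true)
     |-- nested-field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- key: string (nullable = true)
     |    |    |-- subnested-field: array (nullable = true)
     |    |    |    |-- element: struct (containsNull = true)
     |    |    |    |    |-- subkey: string (nullable = true)
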
spark/sql-20/src/main/scala/org/elasticsearch/spark/sql/ScalaEsRowValueReader.scala

Lines changed: 17 additions & 2 deletions
@@ -59,8 +59,23 @@ class ScalaRowValueReader extends ScalaValueReader with RowValueReader with Valu
       else {
         val rowOrd =
           if (inArray) {
-            if (rowColumnsMap.contains(sparkRowField)) {
-              rowColumns(sparkRowField)
+            // Recollect the current field name. If the last thing we read before a new object in a list was an empty
+            // object, we won't be able to find the correct row order for the next row being created.
+            // Example: foo:[{bar: baz, qux:{}},{bar:bizzy}]
+            //                          ^      ^____ This could break, because the parser thinks that
+            //                          \___________ this field is the current one and loads the wrong row order.
+            // By re-resolving the current field, we can avoid this edge case, because that is managed by a stack in the
+            // superclass instead of the local sparkRowField.
+            var latestRowField = if (getCurrentField == null) null else getCurrentField.getFieldName
+            if (latestRowField == null) {
+              throw new IllegalStateException(
+                "No field information could be found while creating map for " +
+                  s"array: previous field [${sparkRowField}], row order [${currentArrayRowOrder}]"
+              )
+            }
+
+            if (rowColumnsMap.contains(latestRowField)) {
+              rowColumns(latestRowField)
             }
             else {
               currentArrayRowOrder
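
The comment in this hunk is the heart of the change: the reader used to trust sparkRowField, a name cached when a field was last seen, but an empty object contributes no fields of its own, so the cache goes stale just as the next array element starts. A self-contained sketch of the two bookkeeping strategies (an illustrative model under assumed names, not the connector's classes) makes the difference concrete:

    import scala.collection.mutable

    // Model parsing of: foo: [ {bar: "baz", qux: {}}, {bar: "bizzy"} ]
    object FieldTracking extends App {
      val fieldStack = mutable.Stack[String]() // what the superclass maintains
      var lastSeenField: String = null         // what the cached sparkRowField held

      def enterField(name: String): Unit = { fieldStack.push(name); lastSeenField = name }
      def exitField(): Unit = fieldStack.pop()

      enterField("foo")              // foo's array value begins
      enterField("bar"); exitField() // first element: bar: "baz"
      enterField("qux")              // qux's value is an empty object {}
      exitField()                    // ...which closes immediately

      // The parser now starts the second array element, {bar: "bizzy"}:
      println(s"cached:   $lastSeenField")                  // qux -> wrong row order
      println(s"resolved: ${fieldStack.headOption.orNull}") // foo -> correct
    }

Re-resolving from the stack is what the getCurrentField call in the hunk does, and the new IllegalStateException covers the corner where even the stack has nothing to offer.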

spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 121 additions & 12 deletions
@@ -294,6 +294,15 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
 
     df.take(1).foreach(println)
     assertEquals(1, df.count())
+    val head = df.head()
+    val arr = head.getSeq[Row](0);
+    assertThat(arr.size, is(2))
+    assertEquals(arr(0).getString(0), "1")
+    assertEquals(arr(0).getString(1), "2")
+    assertEquals(arr(1).getString(0), "unu")
+    assertEquals(arr(1).getString(1), "doi")
+    val topLevel = head.getString(1)
+    assertEquals(topLevel, "root")
   }
 
   @Test
@@ -348,9 +357,30 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val mapping = SchemaUtilsTestable.rowInfo(cfgSettings)
 
     val df = sqc.read.options(newCfg).format("org.elasticsearch.spark.sql").load(target)
-    df.printSchema()
-    df.take(1).foreach(println)
+    // df.printSchema()
+    // df.take(1).foreach(println)
     assertEquals(1, df.count())
+
+    val document = df.take(1).head
+    assertEquals(text, document.getString(1)) // .foo
+    assertEquals(0L, document.getLong(2)) // .level
+    assertEquals(text, document.getString(3)) // .level1
+
+    val bar = document.getStruct(0) // .bar
+    assertEquals(10L, bar.getLong(1)) // .bar.foo2
+    assertEquals(1L, bar.getLong(2)) // .bar.level
+    assertEquals(2L, bar.getLong(3)) // .bar.level2
+
+    val barbar = bar.getStruct(0) // .bar.bar
+    assertEquals(2L, barbar.getLong(1)) // .bar.bar.level
+    assertTrue(barbar.getBoolean(2)) // .bar.bar.level3
+
+    val barbarbar = barbar.getSeq[Row](0) // .bar.bar.bar
+    assertEquals(2, barbarbar.size)
+    val barbarbar0bar = barbarbar.head // .bar.bar.bar.[0]
+    assertEquals(1L, barbarbar0bar.getLong(0)) // .bar.bar.bar.[0].bar
+    val barbarbar1bar = barbarbar.last // .bar.bar.bar.[1]
+    assertEquals(2L, barbarbar1bar.getLong(0)) // .bar.bar.bar.[1].bar
   }
 
   @Test
@@ -370,6 +400,19 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     df.printSchema()
     df.take(1).foreach(println)
    assertEquals(1, df.count())
+
+    val document = df.head()
+    assertEquals(5L, document.getLong(0)) // .foo
+    val nested = document.getStruct(1) // .nested
+    val bar = nested.getSeq[Row](0) // .nested.bar
+    assertEquals(2, bar.size)
+    val bar1 = bar.head // .nested.bar.[1]
+    assertEquals(20L, bar1.getLong(0)) // .nested.bar.[1].age
+    assertEquals(new Timestamp(115, 0, 1, 0, 0, 0, 0), bar1.getTimestamp(1)) // .nested.bar.[1].date
+    val bar2 = bar.last // .nested.bar.[2]
+    assertEquals(20L, bar2.getLong(0)) // .nested.bar.[2].age
+    assertEquals(new Timestamp(115, 0, 1, 0, 0, 0, 0), bar2.getTimestamp(1)) // .nested.bar.[2].date
+    assertEquals("now", nested.getString(1)) // .nested.what
   }
 
   @Test
@@ -1565,16 +1608,19 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val scores = bar.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType]("scores")
     assertEquals("array", scores.dataType.typeName)
 
-    val head = df.head
-    val foo = head.getSeq[Long](0)
-    assertEquals(5, foo(0))
-    assertEquals(6, foo(1))
-    // nested
-    val nested = head.getStruct(1)
-    assertEquals("now", nested.getString(1))
-    val nestedDate = nested.getSeq[Row](0)
-    val nestedScores = nestedDate(0).getSeq[Long](1)
-    assertEquals(2l, nestedScores(1))
+    val document = df.head
+    val foo = document.getSeq[Long](0) // .foo
+    assertEquals(5, foo(0)) // .foo[0]
+    assertEquals(6, foo(1)) // .foo[1]
+    val nested = document.getStruct(1) // .nested
+    assertEquals("now", nested.getString(1)) // .nested.what
+
+    val nestedBar = nested.getSeq[Row](0) // .nested.bar
+    val nestedScores = nestedBar(0).getSeq[Long](1) // .nested.bar.[0].scores
+    assertEquals(2l, nestedScores(1)) // .nested.bar.[0].scores.[1]
+
+    val nestedScores2 = nestedBar(1).getSeq[Long](1) // .nested.bar.[1].scores
+    assertEquals(4l, nestedScores2(1)) // .nested.bar.[1].scores.[1]
   }
 
   //@Test
@@ -2283,6 +2329,69 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertThat(nested.size, is(3))
     assertEquals(nested(0).getString(0), "anne")
     assertEquals(nested(0).getLong(1), 6)
+    assertEquals(nested(1).getString(0), "bob")
+    assertEquals(nested(1).getLong(1), 100)
+    assertEquals(nested(2).getString(0), "charlie")
+    assertEquals(nested(2).getLong(1), 15)
+  }
+
+  @Test
+  def testNestedWithEmptyObjectAtTail() {
+    val mapping = wrapMapping("data",
+      s"""{
+        |  "properties": {
+        |    "name": { "type": "$keyword" },
+        |    "nested-field": {
+        |      "type": "nested",
+        |      "properties": {
+        |        "key": {"type": "$keyword"},
+        |        "subnested-field": {
+        |          "type": "nested",
+        |          "properties": {
+        |            "subkey": {"type": "$keyword"}
+        |          }
+        |        }
+        |      }
+        |    }
+        |  }
+        |}
+      """.stripMargin)
+
+    val index = wrapIndex("sparksql-test-nested-empty-object-at-tail")
+    val typed = "data"
+    val (target, _) = makeTargets(index, typed)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typed, mapping.getBytes(StringUtils.UTF_8))
+
+    val data = """{"name":"nested-empty-object","nested-field":[{"key": "value1","subnested-field":{}},{"key": "value2"}]}""".stripMargin
+
+    sc.makeRDD(Seq(data)).saveJsonToEs(target)
+    val df = sqc.read.format("es").load(index)
+
+    println(df.schema.treeString)
+
+    val dataType = df.schema("nested-field").dataType
+    assertEquals("array", dataType.typeName)
+    val array = dataType.asInstanceOf[ArrayType]
+    assertEquals("struct", array.elementType.typeName)
+    val struct = array.elementType.asInstanceOf[StructType]
+    assertEquals("string", struct("key").dataType.typeName)
+    assertEquals("array", struct("subnested-field").dataType.typeName)
+
+    val subArrayType = struct("subnested-field").dataType
+    assertEquals("array", subArrayType.typeName)
+    val subArray = subArrayType.asInstanceOf[ArrayType]
+    assertEquals("struct", subArray.elementType.typeName)
+    val subStruct = subArray.elementType.asInstanceOf[StructType]
+    assertEquals("string", subStruct("subkey").dataType.typeName)
+
+    val head = df.head()
+    val nested = head.getSeq[Row](1) // .nested-field
+    assertThat(nested.size, is(2))
+    assertEquals(nested(0).getString(0), "value1") // .nested-field.[0].key matches
+    assertEquals(nested(0).getSeq(1).size, 1) // .nested-field.[0].subnested-field is singleton list
+    assertNull(nested(0).getSeq[Row](1).head.get(0)) // .nested-field.[0].subnested-field.[0] is empty object
+    assertEquals(nested(1).getString(0), "value2") // .nested-field.[1].key matches
   }
 
 