Skip to content
Snippets Groups Projects
Commit 2e8c6ca4 authored by vidmantas zemleris's avatar vidmantas zemleris Committed by Michael Armbrust
Browse files

[SPARK-6994] Allow to fetch field values by name in sql.Row

It looked weird that up to now there was no way in Spark's Scala API to access fields of `DataFrame/sql.Row` by name, only by their index.

This patch addresses that gap by adding name-based field access to `Row`.

Author: vidmantas zemleris <vidmantas@vinted.com>

Closes #5573 from vidma/features/row-with-named-fields and squashes the following commits:

6145ae3 [vidmantas zemleris] [SPARK-6994][SQL] Allow to fetch field values by name on Row
9564ebb [vidmantas zemleris] [SPARK-6994][SQL] Add fieldIndex to schema (StructType)
parent 04bf34e3
No related branches found
No related tags found
No related merge requests found
......@@ -306,6 +306,38 @@ trait Row extends Serializable {
*/
def getAs[T](i: Int): T = this(i).asInstanceOf[T] // delegate to apply(i); ClassCastException if the runtime type is not T
/**
 * Returns the value of the field named `fieldName`, cast to type `T`.
 *
 * @throws UnsupportedOperationException when the schema is not defined.
 * @throws IllegalArgumentException when no field with that name exists.
 * @throws ClassCastException when the data type does not match `T`.
 */
def getAs[T](fieldName: String): T = {
  // Resolve the name to an ordinal first, then reuse the index-based accessor.
  val ordinal = fieldIndex(fieldName)
  getAs[T](ordinal)
}
/**
 * Returns the index of the field with the given name.
 *
 * The default implementation always fails: a schema is required to resolve names,
 * and schema-aware Row implementations override this method.
 *
 * @throws UnsupportedOperationException when the schema is not defined.
 * @throws IllegalArgumentException when no field with that name exists.
 */
def fieldIndex(name: String): Int =
  throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.")
/**
 * Returns a Map(name -> value) for the requested field names.
 *
 * @throws UnsupportedOperationException when the schema is not defined.
 * @throws IllegalArgumentException when any requested field name does not exist.
 * @throws ClassCastException when a data type does not match `T`.
 */
def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
  // Pair every requested name with its value, then materialize as an immutable Map.
  val pairs = for (name <- fieldNames) yield name -> getAs[T](name)
  pairs.toMap
}
override def toString(): String = this.mkString("[", ",", "]") // e.g. [value1,value2,1]
/**
......
......@@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
/** No-arg constructor for serialization; both values and schema are populated afterwards by the deserializer. */
protected def this() = this(null, null)
/** Resolves a field name to its ordinal by delegating to the attached schema. */
override def fieldIndex(name: String): Int = schema.fieldIndex(name)
}
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
......
......@@ -1025,6 +1025,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
// Field names as a Set, for fast membership checks.
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
// Lookup from field name to its StructField.
// NOTE: duplicate field names collapse to the last occurrence (Map semantics).
private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
// Lookup from field name to its ordinal position; backs fieldIndex(name).
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
/**
* Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
......@@ -1049,6 +1050,14 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
StructType(fields.filter(f => names.contains(f.name)))
}
/**
 * Returns the ordinal position of the field with the given name.
 *
 * @throws IllegalArgumentException when no field with that name exists.
 */
def fieldIndex(name: String): Int = {
  nameToIndex.get(name) match {
    case Some(index) => index
    case None =>
      throw new IllegalArgumentException(s"""Field "$name" does not exist.""")
  }
}
/** Converts each field of this struct into an AttributeReference carrying its name, type, nullability and metadata. */
protected[sql] def toAttributes: Seq[AttributeReference] =
  for (field <- this) yield AttributeReference(field.name, field.dataType, field.nullable, field.metadata)()
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
import org.apache.spark.sql.types._
import org.scalatest.{Matchers, FunSpec}
class RowTest extends FunSpec with Matchers {

  // Shared fixture: a three-column schema plus one array of values that matches it.
  val rowSchema = StructType(
    StructField("col1", StringType) ::
    StructField("col2", StringType) ::
    StructField("col3", IntegerType) :: Nil)
  val rowValues: Array[Any] = Array("value1", "value2", 1)

  // Same data exposed once with and once without an attached schema.
  val sampleRow: Row = new GenericRowWithSchema(rowValues, rowSchema)
  val noSchemaRow: Row = new GenericRow(rowValues)

  describe("Row (without schema)") {
    it("throws an exception when accessing by fieldName") {
      // Name-based access is undefined without a schema, both directly and via getAs.
      intercept[UnsupportedOperationException] {
        noSchemaRow.fieldIndex("col1")
      }
      intercept[UnsupportedOperationException] {
        noSchemaRow.getAs("col1")
      }
    }
  }

  describe("Row (with schema)") {
    it("fieldIndex(name) returns field index") {
      sampleRow.fieldIndex("col1") should be (0)
      sampleRow.fieldIndex("col3") should be (2)
    }

    it("getAs[T] retrieves a value by fieldname") {
      sampleRow.getAs[String]("col1") should be ("value1")
      sampleRow.getAs[Int]("col3") should be (1)
    }

    it("Accessing non existent field throws an exception") {
      intercept[IllegalArgumentException] {
        sampleRow.getAs[String]("non_existent")
      }
    }

    it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") {
      val expectedMap = Map("col1" -> "value1", "col2" -> "value2")
      sampleRow.getValuesMap(List("col1", "col2")) should be (expectedMap)
    }
  }
}
......@@ -56,6 +56,19 @@ class DataTypeSuite extends FunSuite {
}
}
test("extract field index from a StructType") {
  // Two fields: "a" at position 0, "b" at position 1.
  val schema = StructType(Seq(
    StructField("a", LongType),
    StructField("b", FloatType)))
  assert(schema.fieldIndex("a") === 0)
  assert(schema.fieldIndex("b") === 1)
  // An unknown name must fail loudly rather than return a sentinel.
  intercept[IllegalArgumentException] {
    schema.fieldIndex("non_existent")
  }
}
def checkDataTypeJsonRepr(dataType: DataType): Unit = {
test(s"JSON - $dataType") {
assert(DataType.fromJson(dataType.json) === dataType)
......
......@@ -62,4 +62,14 @@ class RowSuite extends FunSuite {
val de = instance.deserialize(ser).asInstanceOf[Row]
assert(de === row)
}
test("get values by field name on Row created via .toDF") {
  // Rows produced through the DataFrame API carry a schema, so name-based access works.
  val firstRow = Seq((1, Seq(1))).toDF("a", "b").first()
  assert(firstRow.getAs[Int]("a") === 1)
  assert(firstRow.getAs[Seq[Int]]("b") === Seq(1))
  // Unknown column names must raise rather than return a default.
  intercept[IllegalArgumentException] {
    firstRow.getAs[Int]("c")
  }
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment