diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index bd7e51c3b5100874862b3dc51a20a8847363ff86..22c05a247942206fb940656b877ce6be7f710360 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2153,6 +2153,10 @@ private[spark] object Utils extends Logging {
       conf.getInt("spark.executor.instances", 0) == 0
   }
 
+  def tryWithResource[R <: Closeable, T](createResource: => R)(f: R => T): T = {
+    val resource = createResource
+    try f.apply(resource) finally resource.close()
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 0770fab0ae901ed53f89a076bfc7d97d5fb1ca16..8c9853e628d2c156f921c6e1c79fbcd5189d477b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 import scala.util.parsing.combinator.RegexParsers
 
@@ -134,16 +135,18 @@ case class GetJsonObject(json: Expression, path: Expression)
 
     if (parsed.isDefined) {
       try {
-        val parser = jsonFactory.createParser(jsonStr.getBytes)
-        val output = new ByteArrayOutputStream()
-        val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8)
-        parser.nextToken()
-        val matched = evaluatePath(parser, generator, RawStyle, parsed.get)
-        generator.close()
-        if (matched) {
-          UTF8String.fromBytes(output.toByteArray)
-        } else {
-          null
+        Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser =>
+          val output = new ByteArrayOutputStream()
+          val matched = Utils.tryWithResource(
+            jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator =>
+            parser.nextToken()
+            evaluatePath(parser, generator, RawStyle, parsed.get)
+          }
+          if (matched) {
+            UTF8String.fromBytes(output.toByteArray)
+          } else {
+            null
+          }
         }
       } catch {
         case _: JsonProcessingException => null
@@ -250,17 +253,18 @@ case class GetJsonObject(json: Expression, path: Expression)
         // temporarily buffer child matches, the emitted json will need to be
         // modified slightly if there is only a single element written
         val buffer = new StringWriter()
-        val flattenGenerator = jsonFactory.createGenerator(buffer)
-        flattenGenerator.writeStartArray()
 
         var dirty = 0
-        while (p.nextToken() != END_ARRAY) {
-          // track the number of array elements and only emit an outer array if
-          // we've written more than one element, this matches Hive's behavior
-          dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0)
+        Utils.tryWithResource(jsonFactory.createGenerator(buffer)) { flattenGenerator =>
+          flattenGenerator.writeStartArray()
+
+          while (p.nextToken() != END_ARRAY) {
+            // track the number of array elements and only emit an outer array if
+            // we've written more than one element, this matches Hive's behavior
+            dirty += (if (evaluatePath(p, flattenGenerator, nextStyle, xs)) 1 else 0)
+          }
+          flattenGenerator.writeEndArray()
         }
-        flattenGenerator.writeEndArray()
-        flattenGenerator.close()
 
         val buf = buffer.getBuffer
         if (dirty > 1) {
@@ -370,12 +374,8 @@ case class JsonTuple(children: Seq[Expression])
     }
 
     try {
-      val parser = jsonFactory.createParser(json.getBytes)
-
-      try {
-        parseRow(parser, input)
-      } finally {
-        parser.close()
+      Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) {
+        parser => parseRow(parser, input)
       }
     } catch {
       case _: JsonProcessingException =>
@@ -420,12 +420,8 @@ case class JsonTuple(children: Seq[Expression])
 
           // write the output directly to UTF8 encoded byte array
           if (parser.nextToken() != JsonToken.VALUE_NULL) {
-            val generator = jsonFactory.createGenerator(output, JsonEncoding.UTF8)
-
-            try {
-              copyCurrentStructure(generator, parser)
-            } finally {
-              generator.close()
+            Utils.tryWithResource(jsonFactory.createGenerator(output, JsonEncoding.UTF8)) {
+              generator => copyCurrentStructure(generator, parser)
             }
 
             row(idx) = UTF8String.fromBytes(output.toByteArray)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
index b6f3410bad69016b8662970b75e4f2d92bb49f11..d0780028dacb10e803e86811b2b427109bbdd591 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
 import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 private[sql] object InferSchema {
   /**
@@ -47,9 +48,10 @@
       val factory = new JsonFactory()
       iter.map { row =>
         try {
-          val parser = factory.createParser(row)
-          parser.nextToken()
-          inferField(parser)
+          Utils.tryWithResource(factory.createParser(row)) { parser =>
+            parser.nextToken()
+            inferField(parser)
+          }
         } catch {
           case _: JsonParseException =>
             StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
index c51140749c8e6006a731c40ed561c6258451ad51..09b8a9e936a1d630269f51193178a280f1a88e2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 private[sql] object JacksonParser {
   def apply(
@@ -86,9 +87,9 @@ private[sql] object JacksonParser {
 
       case (_, StringType) =>
         val writer = new ByteArrayOutputStream()
-        val generator = factory.createGenerator(writer, JsonEncoding.UTF8)
-        generator.copyCurrentStructure(parser)
-        generator.close()
+        Utils.tryWithResource(factory.createGenerator(writer, JsonEncoding.UTF8)) {
+          generator => generator.copyCurrentStructure(parser)
+        }
         UTF8String.fromBytes(writer.toByteArray)
 
       case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT, FloatType) =>
@@ -245,22 +246,24 @@
 
       iter.flatMap { record =>
         try {
-          val parser = factory.createParser(record)
-          parser.nextToken()
-
-          convertField(factory, parser, schema) match {
-            case null => failedRecord(record)
-            case row: InternalRow => row :: Nil
-            case array: ArrayData =>
-              if (array.numElements() == 0) {
-                Nil
-              } else {
-                array.toArray[InternalRow](schema)
-              }
-            case _ =>
-              sys.error(
-                s"Failed to parse record $record. Please make sure that each line of the file " +
-                  "(or each string in the RDD) is a valid JSON object or an array of JSON objects.")
+          Utils.tryWithResource(factory.createParser(record)) { parser =>
+            parser.nextToken()
+
+            convertField(factory, parser, schema) match {
+              case null => failedRecord(record)
+              case row: InternalRow => row :: Nil
+              case array: ArrayData =>
+                if (array.numElements() == 0) {
+                  Nil
+                } else {
+                  array.toArray[InternalRow](schema)
+                }
+              case _ =>
+                sys.error(
+                  s"Failed to parse record $record. Please make sure that each line of the file " +
+                    "(or each string in the RDD) is a valid JSON object or " +
+                    "an array of JSON objects.")
+            }
           }
         } catch {
           case _: JsonProcessingException =>
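
Reviewer note, not part of the patch: a minimal usage sketch of the new `Utils.tryWithResource` helper, with an illustrative `InputStreamReader` standing in for the Jackson parsers and generators above. The helper implements the loan pattern: the resource is constructed inside the helper (the `createResource` parameter is by-name), handed to the function, and closed in a `finally` block, so `close()` runs on both the normal and the exceptional path, which is what the hand-written try/finally blocks replaced in this patch were doing.

    import java.io.{ByteArrayInputStream, InputStreamReader}
    import org.apache.spark.util.Utils

    // Read one character from an in-memory stream; the reader is closed
    // even if read() throws.
    val first: Int = Utils.tryWithResource(
        new InputStreamReader(new ByteArrayInputStream("{}".getBytes("UTF-8")))) { reader =>
      reader.read()
    }

Because construction happens inside the helper, a failure while creating the resource simply propagates and no `close()` is attempted on a half-initialized value. If the body throws and `close()` then also throws, the `close()` exception supersedes the original one, matching the behavior of the try/finally code this replaces.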