Skip to content
Snippets Groups Projects
Commit 95db8a44 authored by Takeshi YAMAMURO's avatar Takeshi YAMAMURO Committed by Sean Owen
Browse files

[SPARK-15528][SQL] Fix race condition in NumberConverter

## What changes were proposed in this pull request?
A local variable in NumberConverter is wrongly shared between threads.
This pr fixes the race condition.

## How was this patch tested?
Manually checked.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #13391 from maropu/SPARK-15528.
parent 6878f3e2
No related branches found
No related tags found
No related merge requests found
......@@ -21,8 +21,6 @@ import org.apache.spark.unsafe.types.UTF8String
object NumberConverter {
private val value = new Array[Byte](64)
/**
* Divide x by m as if x is an unsigned 64-bit integer. Examples:
* unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2
......@@ -49,7 +47,7 @@ object NumberConverter {
* @param v is treated as an unsigned 64-bit integer
* @param radix must be between MIN_RADIX and MAX_RADIX
*/
private def decode(v: Long, radix: Int): Unit = {
private def decode(v: Long, radix: Int, value: Array[Byte]): Unit = {
var tmpV = v
java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
var i = value.length - 1
......@@ -69,11 +67,9 @@ object NumberConverter {
* @param fromPos is the first element that should be considered
* @return the result should be treated as an unsigned 64-bit integer.
*/
private def encode(radix: Int, fromPos: Int): Long = {
private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = {
var v: Long = 0L
val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once
// val
// exceeds this value
var i = fromPos
while (i < value.length && value(i) >= 0) {
if (v >= bound) {
......@@ -94,7 +90,7 @@ object NumberConverter {
* @param radix must be between MIN_RADIX and MAX_RADIX
* @param fromPos is the first nonzero element
*/
private def byte2char(radix: Int, fromPos: Int): Unit = {
private def byte2char(radix: Int, fromPos: Int, value: Array[Byte]): Unit = {
var i = fromPos
while (i < value.length) {
value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
......@@ -109,9 +105,9 @@ object NumberConverter {
* @param radix must be between MIN_RADIX and MAX_RADIX
* @param fromPos is the first nonzero element
*/
private def char2byte(radix: Int, fromPos: Int): Unit = {
private def char2byte(radix: Int, fromPos: Int, value: Array[Byte]): Unit = {
var i = fromPos
while ( i < value.length) {
while (i < value.length) {
value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
i += 1
}
......@@ -124,8 +120,8 @@ object NumberConverter {
*/
def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = {
if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
|| Math.abs(toBase) < Character.MIN_RADIX
|| Math.abs(toBase) > Character.MAX_RADIX) {
|| Math.abs(toBase) < Character.MIN_RADIX
|| Math.abs(toBase) > Character.MAX_RADIX) {
return null
}
......@@ -136,15 +132,16 @@ object NumberConverter {
var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
// Copy the digits in the right side of the array
val temp = new Array[Byte](64)
var i = 1
while (i <= n.length - first) {
value(value.length - i) = n(n.length - i)
temp(temp.length - i) = n(n.length - i)
i += 1
}
char2byte(fromBase, value.length - n.length + first)
char2byte(fromBase, temp.length - n.length + first, temp)
// Do the conversion by going through a 64 bit integer
var v = encode(fromBase, value.length - n.length + first)
var v = encode(fromBase, temp.length - n.length + first, temp)
if (negative && toBase > 0) {
if (v < 0) {
v = -1
......@@ -156,21 +153,20 @@ object NumberConverter {
v = -v
negative = true
}
decode(v, Math.abs(toBase))
decode(v, Math.abs(toBase), temp)
// Find the first non-zero digit or the last digits if all are zero.
val firstNonZeroPos = {
val firstNonZero = value.indexWhere( _ != 0)
if (firstNonZero != -1) firstNonZero else value.length - 1
val firstNonZero = temp.indexWhere( _ != 0)
if (firstNonZero != -1) firstNonZero else temp.length - 1
}
byte2char(Math.abs(toBase), firstNonZeroPos)
byte2char(Math.abs(toBase), firstNonZeroPos, temp)
var resultStartPos = firstNonZeroPos
if (negative && toBase < 0) {
resultStartPos = firstNonZeroPos - 1
value(resultStartPos) = '-'
temp(resultStartPos) = '-'
}
UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length))
UTF8String.fromBytes(java.util.Arrays.copyOfRange(temp, resultStartPos, temp.length))
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment