diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 1f3fab09e95665e55d7fdad2d9b5cebd399bd4c3..8b79b0cd65a84e099fea1dcd87112e7529aa33cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -111,7 +111,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => nullOrCast[String](_, s => try s.toLong catch { case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L) case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) case DecimalType => nullOrCast[BigDecimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) @@ -131,7 +131,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => nullOrCast[String](_, s => try s.toShort catch { case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort) case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) case DecimalType => nullOrCast[BigDecimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort @@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => nullOrCast[String](_, s => try s.toByte catch { case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte) case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) case DecimalType => nullOrCast[BigDecimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte @@ -162,7 +162,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => nullOrCast[String](_, s => try s.toDouble catch { case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d) case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) @@ -172,7 +172,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case StringType => nullOrCast[String](_, s => try s.toFloat catch { case _: NumberFormatException => null }) - case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f) case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 2cd0d2b0e13853e542c1144c3f7cff6bf3a742c2..4ce0dff9e15862f03811a29f319443cb71b6ba56 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -237,6 +237,13 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("2012-12-11" cast DoubleType, null) checkEvaluation(Literal(123) cast IntegerType, 123) + checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24) + checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) + checkEvaluation(Literal(23f) + Cast(true, FloatType), 24) + checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24) + checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24) + checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24) + intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} }