diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1f3fab09e95665e55d7fdad2d9b5cebd399bd4c3..8b79b0cd65a84e099fea1dcd87112e7529aa33cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -111,7 +111,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toLong catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1L else 0L)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t))
     case DecimalType => nullOrCast[BigDecimal](_, _.toLong)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
@@ -131,7 +131,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toShort catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toShort else 0.toShort)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort)
     case DecimalType => nullOrCast[BigDecimal](_, _.toShort)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
@@ -141,7 +141,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toByte catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1.toByte else 0.toByte)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte)
     case DecimalType => nullOrCast[BigDecimal](_, _.toByte)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
@@ -162,7 +162,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toDouble catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1d else 0d)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t))
     case DecimalType => nullOrCast[BigDecimal](_, _.toDouble)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
@@ -172,7 +172,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
     case StringType => nullOrCast[String](_, s => try s.toFloat catch {
       case _: NumberFormatException => null
     })
-    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0)
+    case BooleanType => nullOrCast[Boolean](_, b => if(b) 1f else 0f)
     case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat)
     case DecimalType => nullOrCast[BigDecimal](_, _.toFloat)
     case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 2cd0d2b0e13853e542c1144c3f7cff6bf3a742c2..4ce0dff9e15862f03811a29f319443cb71b6ba56 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -237,6 +237,13 @@ class ExpressionEvaluationSuite extends FunSuite {
     checkEvaluation("2012-12-11" cast DoubleType, null)
     checkEvaluation(Literal(123) cast IntegerType, 123)
 
+    checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24)
+    checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
+    checkEvaluation(Literal(23f) + Cast(true, FloatType), 24)
+    checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24)
+    checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24)
+    checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24)
+
     intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)}
   }