diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index ab49ed4b5d4ecfce47f129127c6545aaee20bab0..b6330e230afef80c685d8e60f0fcd7d0fa69a13f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -83,9 +83,6 @@ object functions extends LegacyFunctions {
     Column(func.toAggregateExpression(isDistinct))
   }
 
-  private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
-
-
   /**
    * Returns a [[Column]] based on the given column name.
    *
@@ -269,7 +266,8 @@ object functions extends LegacyFunctions {
    * @group agg_funcs
    * @since 1.3.0
    */
-  def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long]
+  def count(columnName: String): TypedColumn[Any, Long] =
+    count(Column(columnName)).as(ExpressionEncoder[Long](flat = true))
 
   /**
    * Aggregate function: returns the number of distinct items in a group.