From d053a31be93d789e3f26cf55d747ecf6ca386c29 Mon Sep 17 00:00:00 2001
From: animesh <animesh@apache.spark>
Date: Wed, 3 Jun 2015 11:28:18 -0700
Subject: [PATCH] [SPARK-7980] [SQL] Support SQLContext.range(end)

1. range() overloaded in SQLContext.scala
2. range() modified in python sql context.py
3. Tests added accordingly in DataFrameSuite.scala and python sql tests.py

Author: animesh <animesh@apache.spark>

Closes #6609 from animeshbaranawal/SPARK-7980 and squashes the following commits:

935899c [animesh] SPARK-7980:python+scala changes
---
 python/pyspark/sql/context.py                        | 12 ++++++++++--
 python/pyspark/sql/tests.py                          |  2 ++
 .../main/scala/org/apache/spark/sql/SQLContext.scala | 11 +++++++++++
 .../scala/org/apache/spark/sql/DataFrameSuite.scala  |  8 ++++++++
 4 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 9fdf43c3e6..1bebfc4837 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -131,7 +131,7 @@ class SQLContext(object):
         return UDFRegistration(self)
 
     @since(1.4)
-    def range(self, start, end, step=1, numPartitions=None):
+    def range(self, start, end=None, step=1, numPartitions=None):
         """
         Create a :class:`DataFrame` with single LongType column named `id`,
         containing elements in a range from `start` to `end` (exclusive) with
@@ -145,10 +145,18 @@ class SQLContext(object):
 
         >>> sqlContext.range(1, 7, 2).collect()
         [Row(id=1), Row(id=3), Row(id=5)]
+
+        >>> sqlContext.range(3).collect()
+        [Row(id=0), Row(id=1), Row(id=2)]
         """
         if numPartitions is None:
             numPartitions = self._sc.defaultParallelism
-        jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
+        if end is None:
+            jdf = self._ssql_ctx.range(0, int(start), int(step), int(numPartitions))
+        else:
+            jdf = self._ssql_ctx.range(int(start), int(end), int(step), int(numPartitions))
+
         return DataFrame(jdf, self)
 
     @ignore_unicode_prefix
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6e498f0af0..a6fce50c76 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -131,6 +131,8 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(self.sqlCtx.range(1, 1).count(), 0)
         self.assertEqual(self.sqlCtx.range(1, 0, -1).count(), 1)
         self.assertEqual(self.sqlCtx.range(0, 1 << 40, 1 << 39).count(), 2)
+        self.assertEqual(self.sqlCtx.range(-2).count(), 0)
+        self.assertEqual(self.sqlCtx.range(3).count(), 3)
 
     def test_explode(self):
         from pyspark.sql.functions import explode
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 91e6385dec..f08fb4fafe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -717,6 +717,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
       StructType(StructField("id", LongType, nullable = false) :: Nil))
   }
 
+  /**
+   * :: Experimental ::
+   * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
+   * in an range from 0 to `end`(exclusive) with step value 1.
+   *
+   * @since 1.4.0
+   * @group dataframe
+   */
+  @Experimental
+  def range(end: Long): DataFrame = range(0, end)
+
   /**
    * :: Experimental ::
    * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index a4fd1058af..9aaec2b064 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -576,5 +576,13 @@ class DataFrameSuite extends QueryTest {
     val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
     assert(res9.count == 2)
     assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
+
+    // only end provided as argument
+    val res10 = TestSQLContext.range(10).select("id")
+    assert(res10.count == 10)
+    assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
+
+    val res11 = TestSQLContext.range(-1).select("id")
+    assert(res11.count == 0)
   }
 }
-- 
GitLab