diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index cac133d0fcf6c21ba4bf567c0a0e3e4533ce6b32..c9ff82d23b3cfb0eeee177c2ece5a583903801d3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -211,6 +211,13 @@ class SparkContext(object):
         """
         return self._jsc.sc().defaultParallelism()
 
+    @property
+    def defaultMinPartitions(self):
+        """
+        Default min number of partitions for Hadoop RDDs when not given by user
+        """
+        return self._jsc.sc().defaultMinPartitions()
+
     def __del__(self):
         self.stop()
 
@@ -264,7 +271,7 @@ class SparkContext(object):
         return RDD(self._jsc.textFile(name, minPartitions), self,
                    UTF8Deserializer())
 
-    def wholeTextFiles(self, path):
+    def wholeTextFiles(self, path, minPartitions=None):
         """
         Read a directory of text files from HDFS, a local file system
         (available on all nodes), or any Hadoop-supported file system
@@ -300,7 +307,8 @@ class SparkContext(object):
         >>> sorted(textFiles.collect())
         [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
         """
-        return RDD(self._jsc.wholeTextFiles(path), self,
+        minPartitions = minPartitions or self.defaultMinPartitions
+        return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
                    PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))
 
     def _checkpointFile(self, name, input_deserializer):
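
For reference, a minimal usage sketch of the API after this patch; the app name and input directory below are hypothetical, not part of the change:

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "wholeTextFilesExample")  # hypothetical app name

    # New read-only property added by this patch; it mirrors the JVM-side
    # SparkContext.defaultMinPartitions().
    print(sc.defaultMinPartitions)

    # Omitting minPartitions now falls back to sc.defaultMinPartitions,
    # matching the default in the Scala wholeTextFiles API.
    pairs = sc.wholeTextFiles("/tmp/my-text-dir")  # hypothetical input directory

    # An explicit value is forwarded to the JVM wholeTextFiles call.
    pairs = sc.wholeTextFiles("/tmp/my-text-dir", minPartitions=4)

    sc.stop()

One design note on the fallback: because it is written as "minPartitions or self.defaultMinPartitions", an explicit minPartitions=0 is also treated as unset and replaced by the default; this is harmless in practice, since 0 is not a meaningful partition count.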