diff --git a/docs/configuration.md b/docs/configuration.md
index 7c5b6d011cfd3e6968b5b0d3a1373c66b5a5c104..e4e4b8d516b75936095ff6d17d7bfe0355af45d8 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -311,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful
     or it will be displayed before the driver exiting. It also can be dumped into disk by
     `sc.dump_profiles(path)`. If some of the profile results had been displayed maually,
     they will not be displayed automatically before driver exiting.
+
+    By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by
+    passing a profiler class in as a parameter to the `SparkContext` constructor.
   </td>
 </tr>
 <tr>
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 9556e4718e5857dc41bd7688a1c9e9fbfa525db0..d3efcdf221d82d5189d67d2f09f31516f0cb9e11 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -45,6 +45,7 @@ from pyspark.storagelevel import StorageLevel
 from pyspark.accumulators import Accumulator, AccumulatorParam
 from pyspark.broadcast import Broadcast
 from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.profiler import Profiler, BasicProfiler
 
 # for back compatibility
 from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
@@ -52,4 +53,5 @@ from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
 __all__ = [
     "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
     "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
+    "Profiler", "BasicProfiler",
 ]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index b8cdbbe3cf2b6421e39bfaba95d55fee4252b7ba..ccbca67656c8db95503ebbe2072a722b182feeaa 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -215,21 +215,6 @@ FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
 COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
 
 
-class PStatsParam(AccumulatorParam):
-    """PStatsParam is used to merge pstats.Stats"""
-
-    @staticmethod
-    def zero(value):
-        return None
-
-    @staticmethod
-    def addInPlace(value1, value2):
-        if value1 is None:
-            return value2
-        value1.add(value2)
-        return value1
-
-
 class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
     """
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 568e21f3803bf608ce65cd586c3d1f93e2595d38..c0dec16ac1b25e3c111671df27eb0e07891df7cb 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@ import shutil
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
-import atexit
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -33,6 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deseria
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD
 from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.profiler import ProfilerCollector, BasicProfiler
 
 from py4j.java_collections import ListConverter
 
@@ -66,7 +66,7 @@ class SparkContext(object):
 
     def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                  environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
-                 gateway=None, jsc=None):
+                 gateway=None, jsc=None, profiler_cls=BasicProfiler):
         """
         Create a new SparkContext. At least the master and app name should be set,
         either through the named parameters here or through C{conf}.
@@ -88,6 +88,9 @@ class SparkContext(object):
         :param conf: A L{SparkConf} object setting Spark properties.
         :param gateway: Use an existing gateway and JVM, otherwise a new JVM
                will be instantiated.
+        :param jsc: The JavaSparkContext instance (optional).
+        :param profiler_cls: A class of custom Profiler used to do profiling
+               (default is pyspark.profiler.BasicProfiler).
 
 
         >>> from pyspark.context import SparkContext
@@ -102,14 +105,14 @@ class SparkContext(object):
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                          conf, jsc)
+                          conf, jsc, profiler_cls)
         except:
             # If an error occurs, clean up in order to allow future SparkContext creation:
             self.stop()
             raise
 
     def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
-                 conf, jsc):
+                 conf, jsc, profiler_cls):
         self.environment = environment or {}
         self._conf = conf or SparkConf(_jvm=self._jvm)
         self._batchSize = batchSize  # -1 represents an unlimited batch size
@@ -192,7 +195,11 @@ class SparkContext(object):
             self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
 
         # profiling stats collected for each PythonRDD
-        self._profile_stats = []
+        if self._conf.get("spark.python.profile", "false") == "true":
+            dump_path = self._conf.get("spark.python.profile.dump", None)
+            self.profiler_collector = ProfilerCollector(profiler_cls, dump_path)
+        else:
+            self.profiler_collector = None
 
     def _initialize_context(self, jconf):
         """
@@ -826,39 +833,14 @@ class SparkContext(object):
         it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
         return list(mappedRDD._collect_iterator_through_file(it))
 
-    def _add_profile(self, id, profileAcc):
-        if not self._profile_stats:
-            dump_path = self._conf.get("spark.python.profile.dump")
-            if dump_path:
-                atexit.register(self.dump_profiles, dump_path)
-            else:
-                atexit.register(self.show_profiles)
-
-        self._profile_stats.append([id, profileAcc, False])
-
     def show_profiles(self):
         """ Print the profile stats to stdout """
-        for i, (id, acc, showed) in enumerate(self._profile_stats):
-            stats = acc.value
-            if not showed and stats:
-                print "=" * 60
-                print "Profile of RDD<id=%d>" % id
-                print "=" * 60
-                stats.sort_stats("time", "cumulative").print_stats()
-                # mark it as showed
-                self._profile_stats[i][2] = True
+        self.profiler_collector.show_profiles()
 
     def dump_profiles(self, path):
         """ Dump the profile stats into directory `path`
         """
-        if not os.path.exists(path):
-            os.makedirs(path)
-        for id, acc, _ in self._profile_stats:
-            stats = acc.value
-            if stats:
-                p = os.path.join(path, "rdd_%d.pstats" % id)
-                stats.dump_stats(p)
-        self._profile_stats = []
+        self.profiler_collector.dump_profiles(path)
 
 
 def _test():
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4408996db0790b1bda4aaa282d5323763a449c68
--- /dev/null
+++ b/python/pyspark/profiler.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import cProfile
+import pstats
+import os
+import atexit
+
+from pyspark.accumulators import AccumulatorParam
+
+
+class ProfilerCollector(object):
+    """
+    This class keeps track of different profilers on a per
+    stage basis. Also this is used to create new profilers for
+    the different stages.
+    """
+
+    def __init__(self, profiler_cls, dump_path=None):
+        self.profiler_cls = profiler_cls
+        self.profile_dump_path = dump_path
+        self.profilers = []
+
+    def new_profiler(self, ctx):
+        """ Create a new profiler using class `profiler_cls` """
+        return self.profiler_cls(ctx)
+
+    def add_profiler(self, id, profiler):
+        """ Add a profiler for RDD `id` """
+        if not self.profilers:
+            if self.profile_dump_path:
+                atexit.register(self.dump_profiles, self.profile_dump_path)
+            else:
+                atexit.register(self.show_profiles)
+
+        self.profilers.append([id, profiler, False])
+
+    def dump_profiles(self, path):
+        """ Dump the profile stats into directory `path` """
+        for id, profiler, _ in self.profilers:
+            profiler.dump(id, path)
+        self.profilers = []
+
+    def show_profiles(self):
+        """ Print the profile stats to stdout """
+        for i, (id, profiler, showed) in enumerate(self.profilers):
+            if not showed and profiler:
+                profiler.show(id)
+                # mark it as showed
+                self.profilers[i][2] = True
+
+
+class Profiler(object):
+    """
+    .. note:: DeveloperApi
+
+    PySpark supports custom profilers, this is to allow for different profilers to
+    be used as well as outputting to different formats than what is provided in the
+    BasicProfiler.
+
+    A custom profiler has to define or inherit the following methods:
+        profile - will produce a system profile of some sort.
+        stats - return the collected stats.
+        dump - dumps the profiles to a path
+        add - adds a profile to the existing accumulated profile
+
+    The profiler class is chosen when creating a SparkContext
+
+    >>> from pyspark import SparkConf, SparkContext
+    >>> from pyspark import BasicProfiler
+    >>> class MyCustomProfiler(BasicProfiler):
+    ...     def show(self, id):
+    ...         print "My custom profiles for RDD:%s" % id
+    ...
+    >>> conf = SparkConf().set("spark.python.profile", "true")
+    >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
+    >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+    >>> sc.show_profiles()
+    My custom profiles for RDD:1
+    My custom profiles for RDD:2
+    >>> sc.stop()
+    """
+
+    def __init__(self, ctx):
+        pass
+
+    def profile(self, func):
+        """ Do profiling on the function `func`"""
+        raise NotImplemented
+
+    def stats(self):
+        """ Return the collected profiling stats (pstats.Stats)"""
+        raise NotImplemented
+
+    def show(self, id):
+        """ Print the profile stats to stdout, id is the RDD id """
+        stats = self.stats()
+        if stats:
+            print "=" * 60
+            print "Profile of RDD<id=%d>" % id
+            print "=" * 60
+            stats.sort_stats("time", "cumulative").print_stats()
+
+    def dump(self, id, path):
+        """ Dump the profile into path, id is the RDD id """
+        if not os.path.exists(path):
+            os.makedirs(path)
+        stats = self.stats()
+        if stats:
+            p = os.path.join(path, "rdd_%d.pstats" % id)
+            stats.dump_stats(p)
+
+
+class PStatsParam(AccumulatorParam):
+    """PStatsParam is used to merge pstats.Stats"""
+
+    @staticmethod
+    def zero(value):
+        return None
+
+    @staticmethod
+    def addInPlace(value1, value2):
+        if value1 is None:
+            return value2
+        value1.add(value2)
+        return value1
+
+
+class BasicProfiler(Profiler):
+    """
+    BasicProfiler is the default profiler, which is implemented based on
+    cProfile and Accumulator
+    """
+    def __init__(self, ctx):
+        Profiler.__init__(self, ctx)
+        # Creates a new accumulator for combining the profiles of different
+        # partitions of a stage
+        self._accumulator = ctx.accumulator(None, PStatsParam)
+
+    def profile(self, func):
+        """ Runs and profiles the method to_profile passed in. A profile object is returned. """
+        pr = cProfile.Profile()
+        pr.runcall(func)
+        st = pstats.Stats(pr)
+        st.stream = None  # make it picklable
+        st.strip_dirs()
+
+        # Adds a new profile to the existing accumulated value
+        self._accumulator.add(st)
+
+    def stats(self):
+        return self._accumulator.value
+
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 014c0aa889c01c6cf84b8f1a076cd6fd83f4175d..b6dd5a3bf028dae8dbefc649575aec72fb695b34 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,7 +31,6 @@ import bisect
 import random
 from math import sqrt, log, isinf, isnan
 
-from pyspark.accumulators import PStatsParam
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer
@@ -2132,9 +2131,13 @@ class PipelinedRDD(RDD):
             return self._jrdd_val
         if self._bypass_serializer:
             self._jrdd_deserializer = NoOpSerializer()
-        enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
-        profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
-        command = (self.func, profileStats, self._prev_jrdd_deserializer,
+
+        if self.ctx.profiler_collector:
+            profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
+        else:
+            profiler = None
+
+        command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
         # the serialized command will be compressed by broadcast
         ser = CloudPickleSerializer()
@@ -2157,9 +2160,9 @@ class PipelinedRDD(RDD):
                                              broadcast_vars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
 
-        if enable_profile:
+        if profiler:
             self._id = self._jrdd_val.id()
-            self.ctx._add_profile(self._id, profileStats)
+            self.ctx.profiler_collector.add_profiler(self._id, profiler)
         return self._jrdd_val
 
     def id(self):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index e694ffcff59e162cf2941ccdda02b80f0cb5a6f0..081a77fbb0be2093b9f4d00af201ca78405ac7bc 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, External
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
 from pyspark import shuffle
+from pyspark.profiler import BasicProfiler
 
 _have_scipy = False
 _have_numpy = False
@@ -743,16 +744,12 @@ class ProfilerTests(PySparkTestCase):
         self.sc = SparkContext('local[4]', class_name, conf=conf)
 
     def test_profiler(self):
+        self.do_computation()
 
-        def heavy_foo(x):
-            for i in range(1 << 20):
-                x = 1
-        rdd = self.sc.parallelize(range(100))
-        rdd.foreach(heavy_foo)
-        profiles = self.sc._profile_stats
-        self.assertEqual(1, len(profiles))
-        id, acc, _ = profiles[0]
-        stats = acc.value
+        profilers = self.sc.profiler_collector.profilers
+        self.assertEqual(1, len(profilers))
+        id, profiler, _ = profilers[0]
+        stats = profiler.stats()
         self.assertTrue(stats is not None)
         width, stat_list = stats.get_print_list([])
         func_names = [func_name for fname, n, func_name in stat_list]
@@ -763,6 +760,31 @@ class ProfilerTests(PySparkTestCase):
         self.sc.dump_profiles(d)
         self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
 
+    def test_custom_profiler(self):
+        class TestCustomProfiler(BasicProfiler):
+            def show(self, id):
+                self.result = "Custom formatting"
+
+        self.sc.profiler_collector.profiler_cls = TestCustomProfiler
+
+        self.do_computation()
+
+        profilers = self.sc.profiler_collector.profilers
+        self.assertEqual(1, len(profilers))
+        _, profiler, _ = profilers[0]
+        self.assertTrue(isinstance(profiler, TestCustomProfiler))
+
+        self.sc.show_profiles()
+        self.assertEqual("Custom formatting", profiler.result)
+
+    def do_computation(self):
+        def heavy_foo(x):
+            for i in range(1 << 20):
+                x = 1
+
+        rdd = self.sc.parallelize(range(100))
+        rdd.foreach(heavy_foo)
+
 
 class ExamplePointUDT(UserDefinedType):
     """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7e5343c973dc5b028dfc50c7f306a157fef28d0c..8a93c320ec5d342be490f2ba73fc5850bfd85c8f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,8 +23,6 @@ import sys
 import time
 import socket
 import traceback
-import cProfile
-import pstats
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
@@ -90,19 +88,15 @@ def main(infile, outfile):
         command = pickleSer._read_with_length(infile)
         if isinstance(command, Broadcast):
             command = pickleSer.loads(command.value)
-        (func, stats, deserializer, serializer) = command
+        (func, profiler, deserializer, serializer) = command
         init_time = time.time()
 
         def process():
             iterator = deserializer.load_stream(infile)
             serializer.dump_stream(func(split_index, iterator), outfile)
 
-        if stats:
-            p = cProfile.Profile()
-            p.runcall(process)
-            st = pstats.Stats(p)
-            st.stream = None  # make it picklable
-            stats.add(st.strip_dirs())
+        if profiler:
+            profiler.profile(process)
         else:
             process()
     except Exception:
diff --git a/python/run-tests b/python/run-tests
index 9ee19ed6e6b2654b003bab9b06e165842d64ed0a..53c34557d9af1ef41a3e672127f028713ac9f337 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -57,6 +57,7 @@ function run_core_tests() {
     PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
     PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
     run_test "pyspark/serializers.py"
+    run_test "pyspark/profiler.py" 
     run_test "pyspark/shuffle.py"
     run_test "pyspark/tests.py"
 }