diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1e9b3bb5c03480a2e6e10c8df85fa8087f44c76e..8394fe6a319358f58a7ee7cac430204743c0fbc0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -21,6 +21,7 @@ from collections import defaultdict
 from itertools import chain, ifilter, imap, product
 import operator
 import os
+import sys
 import shlex
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
@@ -32,6 +33,7 @@ from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
+from pyspark.rddsampler import RDDSampler
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -165,14 +167,60 @@ class RDD(object):
                    .reduceByKey(lambda x, _: x) \
                    .map(lambda (x, _): x)
 
-    # TODO: sampling needs to be re-implemented due to Batch
-    #def sample(self, withReplacement, fraction, seed):
-    #    jrdd = self._jrdd.sample(withReplacement, fraction, seed)
-    #    return RDD(jrdd, self.ctx)
+    def sample(self, withReplacement, fraction, seed):
+        """
+        Return a sampled subset of this RDD (relies on numpy and falls back
+        on default random generator if numpy is unavailable).
+
+        >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
+        [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
+        """
+        return self.mapPartitionsWithSplit(RDDSampler(withReplacement, fraction, seed).func, True)
+
+    # this is ported from scala/spark/RDD.scala
+    def takeSample(self, withReplacement, num, seed):
+        """
+        Return a fixed-size sampled subset of this RDD (currently requires numpy).
+
+        >>> sc.parallelize(range(0, 10)).takeSample(True, 10, 1) #doctest: +SKIP
+        [4, 2, 1, 8, 2, 7, 0, 4, 1, 4]
+        """
 
-    #def takeSample(self, withReplacement, num, seed):
-    #    vals = self._jrdd.takeSample(withReplacement, num, seed)
-    #    return [load_pickle(bytes(x)) for x in vals]
+        fraction = 0.0
+        total = 0
+        multiplier = 3.0
+        initialCount = self.count()
+        maxSelected = 0
+
+        if (num < 0):
+            raise ValueError
+
+        if initialCount > sys.maxint - 1:
+            maxSelected = sys.maxint - 1
+        else:
+            maxSelected = initialCount
+
+        if num > initialCount and not withReplacement:
+            total = maxSelected
+            fraction = multiplier * (maxSelected + 1) / initialCount
+        else:
+            fraction = multiplier * (num + 1) / initialCount
+            total = num
+
+        samples = self.sample(withReplacement, fraction, seed).collect()
+    
+        # If the first sample didn't turn out large enough, keep trying to take samples;
+        # this shouldn't happen often because we use a big multiplier for their initial size.
+        # See: scala/spark/RDD.scala
+        while len(samples) < total:
+            if seed > sys.maxint - 2:
+                seed = -1
+            seed += 1
+            samples = self.sample(withReplacement, fraction, seed).collect()
+
+        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler.shuffle(samples)
+        return samples[0:total]
 
     def union(self, other):
         """
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..aca2ef3b51e98239581adf45492982ae2ae8adc3
--- /dev/null
+++ b/python/pyspark/rddsampler.py
@@ -0,0 +1,112 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+import random
+
+class RDDSampler(object):
+    def __init__(self, withReplacement, fraction, seed):
+        try:
+            import numpy
+            self._use_numpy = True
+        except ImportError:
+            print >> sys.stderr, "NumPy does not appear to be installed. Falling back to default random generator for sampling."
+            self._use_numpy = False
+
+        self._seed = seed
+        self._withReplacement = withReplacement
+        self._fraction = fraction
+        self._random = None
+        self._split = None
+        self._rand_initialized = False
+
+    def initRandomGenerator(self, split):
+        if self._use_numpy:
+            import numpy
+            self._random = numpy.random.RandomState(self._seed)
+            for _ in range(0, split):
+                # discard the next few values in the sequence to have a
+                # different seed for the different splits
+                self._random.randint(sys.maxint)
+        else:
+            import random
+            random.seed(self._seed)
+            for _ in range(0, split):
+                # discard the next few values in the sequence to have a
+                # different seed for the different splits
+                random.randint(0, sys.maxint)
+        self._split = split
+        self._rand_initialized = True
+
+    def getUniformSample(self, split):
+        if not self._rand_initialized or split != self._split:
+            self.initRandomGenerator(split)
+
+        if self._use_numpy:
+            return self._random.random_sample()
+        else:
+            return random.uniform(0.0, 1.0)
+
+    def getPoissonSample(self, split, mean):
+        if not self._rand_initialized or split != self._split:
+            self.initRandomGenerator(split)
+        
+        if self._use_numpy:
+            return self._random.poisson(mean)
+        else:
+            # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
+            # drawing a sequence of numbers delta_j ~ Exp(mean)
+            num_arrivals = 1
+            cur_time = 0.0
+
+            cur_time += random.expovariate(mean)
+
+            if cur_time > 1.0:
+                return 0
+
+            while(cur_time <= 1.0):
+                cur_time += random.expovariate(mean)
+                num_arrivals += 1
+
+            return (num_arrivals - 1)
+    
+    def shuffle(self, vals):
+        if self._random == None or split != self._split:
+            self.initRandomGenerator(0)  # this should only ever called on the master so
+            # the split does not matter
+        
+        if self._use_numpy:
+            self._random.shuffle(vals)
+        else:
+            random.shuffle(vals, self._random)
+
+    def func(self, split, iterator):
+        if self._withReplacement:            
+            for obj in iterator:
+                # For large datasets, the expected number of occurrences of each element in a sample with
+                # replacement is Poisson(frac). We use that to get a count for each element.                                   
+                count = self.getPoissonSample(split, mean = self._fraction)
+                for _ in range(0, count):
+                    yield obj
+        else:
+            for obj in iterator:
+                if self.getUniformSample(split) <= self._fraction:
+                    yield obj
+
+            
+            
+