diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index e5d62a466cab6909da7ac6ddf4571bcb1fd91236..abb284d1e3dd9470277365ca07e7423aa0f55149 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -45,7 +45,7 @@ from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.rdd import RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
-    CloudPickleSerializer
+    CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 
@@ -1870,6 +1870,21 @@ class SchemaRDD(RDD):
         rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
         return SchemaRDD(rdd, self.sql_ctx)
 
+    def toJSON(self, use_unicode=False):
+        """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+
+        >>> srdd1 = sqlCtx.jsonRDD(json)
+        >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
+        >>> srdd2 = sqlCtx.sql("SELECT * FROM table1")
+        >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+        True
+        >>> srdd3 = sqlCtx.sql("SELECT field3.field4 FROM table1")
+        >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+        True
+        """
+        rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+        return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
     def saveAsParquetFile(self, path):
         """Save the contents as a Parquet file, preserving the schema.
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 904a276ef3ffba8ae28cac4381ee80f2ef4e95f1..f8970cd3e6360a6d5c9bd2e20e281069a7027dcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,17 +17,19 @@
 
 package org.apache.spark.sql
 
-import java.util.{List => JList}
-
-import org.apache.spark.api.python.SerDeUtil
+import java.io.StringWriter
+import java.util.{List => JList, Map => JMap}
 
 import scala.collection.JavaConversions._
 
+import com.fasterxml.jackson.core.JsonFactory
+
 import net.razorvine.pickle.Pickler
 
 import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
 import org.apache.spark.annotation.{AlphaComponent, Experimental}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.api.java.JavaSchemaRDD
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -35,6 +37,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.json.JsonRDD
 import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
 import org.apache.spark.storage.StorageLevel
 
@@ -131,6 +134,19 @@ class SchemaRDD(
    */
   lazy val schema: StructType = queryExecution.analyzed.schema
 
+  /**
+   * Returns a new RDD with each row transformed to a JSON string.
+   *
+   * @group schema
+   */
+  def toJSON: RDD[String] = {
+    val rowSchema = this.schema
+    this.mapPartitions { iter =>
+      val jsonFactory = new JsonFactory()
+      iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
+    }
+  }
+
   // =======================================================================
   // Query DSL
   // =======================================================================
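
For orientation, here is a minimal sketch of how the new Scala API could be exercised end to end. It is illustrative only: `sc`, the `Person` case class, and the sample data are assumptions, and `createSchemaRDD` is the Spark 1.2-era `SQLContext` implicit.

```scala
// Hypothetical usage of the new SchemaRDD.toJSON (names and data are made up).
case class Person(name: String, age: Int)

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.createSchemaRDD  // implicit RDD[Product] -> SchemaRDD conversion

val people: org.apache.spark.sql.SchemaRDD =
  sc.parallelize(Seq(Person("Alice", 30), Person("Bob", 25)))

// Each Row is rendered as one JSON document keyed by the schema's field names.
people.toJSON.collect().foreach(println)
// e.g. {"name":"Alice","age":30}
//      {"name":"Bob","age":25}
```
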
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
index 78e8d908fe0c8ef2bc0a558113f6283c0ee15d83..ac4844f9b929035e322c17b1d0d09433b49513dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala
@@ -125,6 +125,12 @@ class JavaSchemaRDD(
 
   // Transformations (return a new RDD)
 
+  /**
+   * Returns a new RDD with each row transformed to a JSON string.
+   */
+  def toJSON(): JavaRDD[String] =
+    baseSchemaRDD.toJSON.toJavaRDD
+
   /**
    * Return a new RDD that is reduced into `numPartitions` partitions.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index d9d7a3fea3963718964e932c273310936097fad6..ffb9548356d1dcb377ad616c4a1093fc189697f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql.json
 import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.types.util.DataTypeConversions
 
+import java.io.StringWriter
+
 import scala.collection.Map
 import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
 import scala.math.BigDecimal
 import java.sql.{Date, Timestamp}
 
 import com.fasterxml.jackson.core.JsonProcessingException
+import com.fasterxml.jackson.core.JsonFactory
 import com.fasterxml.jackson.databind.ObjectMapper
 
 import org.apache.spark.rdd.RDD
@@ -424,4 +427,61 @@ private[sql] object JsonRDD extends Logging {
 
     row
   }
+
+  /**
+   * Transforms a single Row into a JSON string using Jackson.
+   * @param rowSchema the schema describing the structure of the row
+   * @param jsonFactory a JsonFactory object used to construct a JsonGenerator
+   * @param row the row to convert
+   */
+  private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = {
+    val writer = new StringWriter()
+    val gen = jsonFactory.createGenerator(writer)
+
+    def valWriter: (DataType, Any) => Unit = {
+      case (_, null) | (NullType, _) => gen.writeNull()
+      case (StringType, v: String) => gen.writeString(v)
+      case (TimestampType, v: java.sql.Timestamp) => gen.writeString(v.toString)
+      case (IntegerType, v: Int) => gen.writeNumber(v)
+      case (ShortType, v: Short) => gen.writeNumber(v)
+      case (FloatType, v: Float) => gen.writeNumber(v)
+      case (DoubleType, v: Double) => gen.writeNumber(v)
+      case (LongType, v: Long) => gen.writeNumber(v)
+      case (DecimalType(), v: scala.math.BigDecimal) => gen.writeNumber(v.bigDecimal)
+      case (DecimalType(), v: java.math.BigDecimal) => gen.writeNumber(v)
+      case (ByteType, v: Byte) => gen.writeNumber(v.toInt)
+      case (BinaryType, v: Array[Byte]) => gen.writeBinary(v)
+      case (BooleanType, v: Boolean) => gen.writeBoolean(v)
+      case (DateType, v) => gen.writeString(v.toString)
+      case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v)
+
+      case (ArrayType(elementType, _), v: Seq[_]) =>
+        gen.writeStartArray()
+        v.foreach(valWriter(elementType, _))
+        gen.writeEndArray()
+
+      case (MapType(_, valueType, _), v: Map[_, _]) =>
+        gen.writeStartObject()
+        v.foreach { p =>
+          gen.writeFieldName(p._1.toString)
+          valWriter(valueType, p._2)
+        }
+        gen.writeEndObject()
+
+      case (StructType(fields), v: Seq[_]) =>
+        gen.writeStartObject()
+        fields.zip(v).foreach {
+          case (_, null) =>
+          case (field, v) =>
+            gen.writeFieldName(field.name)
+            valWriter(field.dataType, v)
+        }
+        gen.writeEndObject()
+    }
+
+    valWriter(rowSchema, row)
+    gen.close()
+    writer.toString
+  }
+
 }
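
Since `rowToJSON` leans entirely on Jackson's streaming generator, a self-contained sketch of that pattern may help reviewers. It uses only plain Jackson, no Spark types, and the field names are invented:

```scala
import java.io.StringWriter
import com.fasterxml.jackson.core.JsonFactory

// Plain-Jackson sketch of the streaming pattern rowToJSON relies on.
val factory = new JsonFactory()  // reusable, like the per-partition factory above
val writer = new StringWriter()
val gen = factory.createGenerator(writer)

gen.writeStartObject()                             // {
gen.writeFieldName("f1"); gen.writeNumber(1)       //   "f1":1,
gen.writeFieldName("f2"); gen.writeString("row1")  //   "f2":"row1"
gen.writeEndObject()                               // }
gen.close()  // flushes buffered output into the writer

println(writer.toString)  // {"f1":1,"f2":"row1"}
```
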
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index f8ca2c773d9ab527a0fa70fff43087f8db8a86d3..f088d413257a929cc36246f57eaea641df693e9a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
 import org.apache.spark.sql.{Row, SQLConf, QueryTest}
+import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext._
 
@@ -779,4 +780,123 @@ class JsonSuite extends QueryTest {
       Seq(null, null, null, Seq(Seq(null, Seq(1, 2, 3)))) :: Nil
     )
   }
+
+  test("SPARK-4228 SchemaRDD to JSON")
+  {
+    val schema1 = StructType(
+      StructField("f1", IntegerType, false) ::
+      StructField("f2", StringType, false) ::
+      StructField("f3", BooleanType, false) ::
+      StructField("f4", ArrayType(StringType), nullable = true) ::
+      StructField("f5", IntegerType, true) :: Nil)
+
+    val rowRDD1 = unparsedStrings.map { r =>
+      val values = r.split(",").map(_.trim)
+      val v5 = try values(3).toInt catch {
+        case _: NumberFormatException => null
+      }
+      Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
+    }
+
+    val schemaRDD1 = applySchema(rowRDD1, schema1)
+    schemaRDD1.registerTempTable("applySchema1")
+    val schemaRDD2 = schemaRDD1.toSchemaRDD
+    val result = schemaRDD2.toJSON.collect()
+    assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
+    assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
+
+    val schema2 = StructType(
+      StructField("f1", StructType(
+        StructField("f11", IntegerType, false) ::
+        StructField("f12", BooleanType, false) :: Nil), false) ::
+      StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil)
+
+    val rowRDD2 = unparsedStrings.map { r =>
+      val values = r.split(",").map(_.trim)
+      val v4 = try values(3).toInt catch {
+        case _: NumberFormatException => null
+      }
+      Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
+    }
+
+    val schemaRDD3 = applySchema(rowRDD2, schema2)
+    schemaRDD3.registerTempTable("applySchema2")
+    val schemaRDD4 = schemaRDD3.toSchemaRDD
+    val result2 = schemaRDD4.toJSON.collect()
+
+    assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
+    assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
+
+    val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)
+    val primTable = jsonRDD(jsonSchemaRDD.toJSON)
+    primTable.registerTempTable("primativeTable")
+    checkAnswer(
+        sql("select * from primativeTable"),
+        (BigDecimal("92233720368547758070"),
+        true,
+        1.7976931348623157E308,
+        10,
+        21474836470L,
+        "this is a simple string.") :: Nil
+      )
+
+    val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1)
+    val compTable = jsonRDD(complexJsonSchemaRDD.toJSON)
+    compTable.registerTempTable("complexTable")
+    // Access elements of a primitive array.
+    checkAnswer(
+      sql("select arrayOfString[0], arrayOfString[1], arrayOfString[2] from complexTable"),
+      ("str1", "str2", null) :: Nil
+    )
+
+    // Access an array of null values.
+    checkAnswer(
+      sql("select arrayOfNull from complexTable"),
+      Seq(Seq(null, null, null, null)) :: Nil
+    )
+
+    // Access elements of a BigInteger array (we use DecimalType internally).
+    checkAnswer(
+      sql("select arrayOfBigInteger[0], arrayOfBigInteger[1], arrayOfBigInteger[2] from complexTable"),
+      (BigDecimal("922337203685477580700"), BigDecimal("-922337203685477580800"), null) :: Nil
+    )
+
+    // Access elements of an array of arrays.
+    checkAnswer(
+      sql("select arrayOfArray1[0], arrayOfArray1[1] from complexTable"),
+      (Seq("1", "2", "3"), Seq("str1", "str2")) :: Nil
+    )
+
+    // Access elements of an array of arrays of doubles.
+    checkAnswer(
+      sql("select arrayOfArray2[0], arrayOfArray2[1] from complexTable"),
+      (Seq(1.0, 2.0, 3.0), Seq(1.1, 2.1, 3.1)) :: Nil
+    )
+
+    // Access elements of an array inside a field of type ArrayType(ArrayType).
+    checkAnswer(
+      sql("select arrayOfArray1[1][1], arrayOfArray2[1][1] from complexTable"),
+      ("str2", 2.1) :: Nil
+    )
+
+    // Access a struct and fields inside of it.
+    checkAnswer(
+      sql("select struct, struct.field1, struct.field2 from complexTable"),
+      Row(
+        Row(true, BigDecimal("92233720368547758070")),
+        true,
+        BigDecimal("92233720368547758070")) :: Nil
+    )
+
+    // Access an array field of a struct.
+    checkAnswer(
+      sql("select structWithArrayFields.field1, structWithArrayFields.field2 from complexTable"),
+      (Seq(4, 5, 6), Seq("str1", "str2")) :: Nil
+    )
+
+    // Access elements of an array field of a struct.
+    checkAnswer(
+      sql("select structWithArrayFields.field1[1], structWithArrayFields.field2[3] from complexTable"),
+      (5, null) :: Nil
+    )
+  }
 }
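
Finally, a condensed view of the round-trip property the suite exercises above. This is a sketch, with `sqlContext`, `sc`, and the one-line dataset assumed:

```scala
// Round trip: SchemaRDD -> JSON strings -> SchemaRDD.
val original = sqlContext.jsonRDD(sc.parallelize(
  """{"f1":1,"f2":"row1","f3":{"f4":11}}""" :: Nil))
val roundTripped = sqlContext.jsonRDD(original.toJSON)

// Both sides go through the same schema inference, so the inferred schema
// and the row contents should survive the round trip intact.
assert(roundTripped.schema == original.schema)
assert(roundTripped.collect().toSeq == original.collect().toSeq)
```
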