Commit d252b2d5 authored by Yanbo Liang, committed by Joseph K. Bradley

[SPARK-12309][ML] Use sqlContext from MLlibTestSparkContext for spark.ml test suites

Use ```sqlContext``` from ```MLlibTestSparkContext``` rather than creating a new one for the spark.ml test suites. I have checked thoroughly and found four test cases that need to be updated.

cc mengxr jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #10279 from yanboliang/spark-12309.
parent 860dc7f2
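
For context, the pattern applied throughout this patch looks roughly like the sketch below: a suite mixes in ```MLlibTestSparkContext```, which initializes ```sc``` and ```sqlContext``` in ```beforeAll```, so test bodies reuse those shared instances instead of constructing a ```new SQLContext(sc)```. The suite name and test body here are illustrative only, not part of the patch.

```scala
// A minimal sketch of the shared-context pattern; ExampleSuite and its test are
// illustrative and do not appear in this commit.
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext

class ExampleSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("uses the shared sqlContext") {
    // sc and sqlContext are set up by MLlibTestSparkContext.beforeAll(),
    // so the test does not create its own SQLContext.
    val df = sqlContext.createDataFrame(
      sc.parallelize(Seq(Vectors.dense(1.0, 2.0)).map(Tuple1.apply), 2)
    ).toDF("features")
    assert(df.count() === 1)
  }
}
```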
MinMaxScalerSuite.scala
@@ -21,13 +21,11 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row

 class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   test("MinMaxScaler fit basic case") {
-    val sqlContext = new SQLContext(sc)
     val data = Array(
       Vectors.dense(1, 0, Long.MinValue),
       Vectors.dense(2, 0, 0),
NormalizerSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}

 class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -61,7 +61,6 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       Vectors.sparse(3, Seq())
     )
-    val sqlContext = new SQLContext(sc)
     dataFrame = sqlContext.createDataFrame(sc.parallelize(data, 2).map(NormalizerSuite.FeatureData))
     normalizer = new Normalizer()
       .setInputCol("features")
VectorSlicerSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}

 class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -54,8 +54,6 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
   }

   test("Test vector slicer") {
-    val sqlContext = new SQLContext(sc)
     val data = Array(
       Vectors.sparse(5, Seq((0, -2.0), (1, 2.3))),
       Vectors.dense(-2.0, 2.3, 0.0, 0.0, 1.0),
TreeTests.scala
@@ -42,7 +42,7 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val sqlContext = new SQLContext(data.sparkContext)
+    val sqlContext = SQLContext.getOrCreate(data.sparkContext)
     import sqlContext.implicits._
     val df = data.toDF()
     val numFeatures = data.first().features.size
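
The change in ```TreeTests``` is slightly different: it is a helper object rather than a suite, so it cannot mix in ```MLlibTestSparkContext```. Instead of building another context, it asks for the one that is already registered via ```SQLContext.getOrCreate```. A minimal sketch of the distinction (the ```SqlContextHelper``` object and ```sqlContextFor``` name are made up for illustration):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Hypothetical helper, not part of the patch, showing why getOrCreate is the
// better fit for a shared test utility than constructing a fresh context.
object SqlContextHelper {
  def sqlContextFor(sc: SparkContext): SQLContext = {
    // new SQLContext(sc) would build a second context alongside whatever the test
    // harness already created; SQLContext.getOrCreate returns the already-registered
    // singleton instance, creating one only if none exists yet.
    SQLContext.getOrCreate(sc)
  }
}
```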
CrossValidatorSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.ml.regression.LinearRegression
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType

 class CrossValidatorSuite
@@ -39,7 +39,6 @@ class CrossValidatorSuite
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val sqlContext = new SQLContext(sc)
     dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
   }