Skip to content
Snippets Groups Projects
Commit c1080b6f authored by DB Tsai's avatar DB Tsai Committed by Xiangrui Meng
Browse files

[SPARK-7568] [ML] ml.LogisticRegression doesn't output the right prediction

The difference is because we previously don't fit the intercept in Spark 1.3. Here, we change the input `String` so that the probability of instance 6 can be classified as `1.0` without any ambiguity.

with lambda = 0.001 in current LOR implementation, the prediction is
```
(4, spark i j k) --> prob=[0.1596407738787411,0.8403592261212589], prediction=1.0
(5, l m n) --> prob=[0.8378325685476612,0.16216743145233883], prediction=0.0
(6, spark hadoop spark) --> prob=[0.0692663313297627,0.9307336686702373], prediction=1.0
(7, apache hadoop) --> prob=[0.9821575333444208,0.01784246665557917], prediction=0.0
```
and the training accuracy is
```
(0, a b c d e spark) --> prob=[0.0021342419881406746,0.9978657580118594], prediction=1.0
(1, b d) --> prob=[0.9959176174854043,0.004082382514595685], prediction=0.0
(2, spark f g h) --> prob=[0.0014541569986711233,0.9985458430013289], prediction=1.0
(3, hadoop mapreduce) --> prob=[0.9982978367343561,0.0017021632656438518], prediction=0.0
```

Author: DB Tsai <dbt@netflix.com>

Closes #6109 from dbtsai/lor-example and squashes the following commits:

ac63ce4 [DB Tsai] first commit
parent 1b8625f4
No related branches found
No related tags found
No related merge requests found
......@@ -66,7 +66,7 @@ public class JavaSimpleTextClassificationPipeline {
.setOutputCol("features");
LogisticRegression lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.01);
.setRegParam(0.001);
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
......@@ -77,7 +77,7 @@ public class JavaSimpleTextClassificationPipeline {
List<Document> localTest = Lists.newArrayList(
new Document(4L, "spark i j k"),
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(6L, "spark hadoop spark"),
new Document(7L, "apache hadoop"));
DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
......
......@@ -48,7 +48,7 @@ if __name__ == "__main__":
# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.01)
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
# Fit the pipeline to training documents.
......@@ -58,7 +58,7 @@ if __name__ == "__main__":
Document = Row("id", "text")
test = sc.parallelize([(4, "spark i j k"),
(5, "l m n"),
(6, "mapreduce spark"),
(6, "spark hadoop spark"),
(7, "apache hadoop")]) \
.map(lambda x: Document(*x)).toDF()
......
......@@ -64,7 +64,7 @@ object SimpleTextClassificationPipeline {
.setOutputCol("features")
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(0.01)
.setRegParam(0.001)
val pipeline = new Pipeline()
.setStages(Array(tokenizer, hashingTF, lr))
......@@ -75,7 +75,7 @@ object SimpleTextClassificationPipeline {
val test = sc.parallelize(Seq(
Document(4L, "spark i j k"),
Document(5L, "l m n"),
Document(6L, "mapreduce spark"),
Document(6L, "spark hadoop spark"),
Document(7L, "apache hadoop")))
// Make predictions on test documents.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment