Skip to content
Snippets Groups Projects
Unverified Commit f3fe5543 authored by Sean Owen's avatar Sean Owen
Browse files

[SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to...

[SPARK-10835][ML] Word2Vec should accept non-null string array, in addition to existing null string array

## What changes were proposed in this pull request?

To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram

## How was this patch tested?

Jenkins tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #15179 from srowen/SPARK-10835.
parent 7c382524
No related branches found
No related tags found
No related merge requests found
...@@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params ...@@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema. * Validate and transform the input schema.
*/ */
protected def validateAndTransformSchema(schema: StructType): StructType = { protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
} }
} }
......
...@@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul ...@@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newInstance = testDefaultReadWrite(instance) val newInstance = testDefaultReadWrite(instance)
assert(newInstance.getVectors.collect() === instance.getVectors.collect()) assert(newInstance.getVectors.collect() === instance.getVectors.collect())
} }
test("Word2Vec works with input that is non-nullable (NGram)") {
val spark = this.spark
import spark.implicits._
val sentence = "a q s t q s t b b b s t m s t m q "
val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text")
val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams")
val ngramDF = ngram.transform(docDF)
val model = new Word2Vec()
.setVectorSize(2)
.setInputCol("ngrams")
.setOutputCol("result")
.fit(ngramDF)
// Just test that this transformation succeeds
model.transform(ngramDF).collect()
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment