diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index c96f99cb83434a32bd19717fc84fbe604c79c9bc..703bcdf4ca72525f765931b7ca2302d7dd02d3e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -40,13 +40,13 @@ private[r] class LogisticRegressionWrapper private ( private val lrModel: LogisticRegressionModel = pipeline.stages(1).asInstanceOf[LogisticRegressionModel] - val rFeatures: Array[String] = if (lrModel.getFitIntercept) { + lazy val rFeatures: Array[String] = if (lrModel.getFitIntercept) { Array("(Intercept)") ++ features } else { features } - val rCoefficients: Array[Double] = { + lazy val rCoefficients: Array[Double] = { val numRows = lrModel.coefficientMatrix.numRows val numCols = lrModel.coefficientMatrix.numCols val numColsWithIntercept = if (lrModel.getFitIntercept) numCols + 1 else numCols diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index d34de30931143e6de2443c59f3b0eb13cf430412..48c87743dee605773f4951918d01b9ef4beb6f1f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -36,11 +36,11 @@ private[r] class MultilayerPerceptronClassifierWrapper private ( import MultilayerPerceptronClassifierWrapper._ - val mlpModel: MultilayerPerceptronClassificationModel = + private val mlpModel: MultilayerPerceptronClassificationModel = pipeline.stages(1).asInstanceOf[MultilayerPerceptronClassificationModel] - val weights: Array[Double] = mlpModel.weights.toArray - val layers: Array[Int] = mlpModel.layers + lazy val weights: Array[Double] = mlpModel.weights.toArray + lazy val layers: Array[Int] = mlpModel.layers def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset)