Regression Analysis (DataFrame)
Code example (see the notebook code)
import org.apache.spark._
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
To work with DataFrames, use the SparkSession (spark) rather than the SparkContext (sc); the shell creates the session automatically.
// SparkSession available as 'spark'
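In the spark-shell the session already exists, but for reference, a standalone application would build its own SparkSession roughly like this (the app name and local master are illustrative placeholders, not part of the notebook):

// Sketch: building a SparkSession outside the shell.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("LinearRegressionDF")   // hypothetical app name
  .master("local[*]")              // run locally on all cores; adjust for a cluster
  .getOrCreate()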
val inputLines = spark.sparkContext.textFile("data/regression.txt")
inputLines: org.apache.spark.rdd.RDD[String] = data/regression.txt MapPartitionsRDD[1] at textFile at <console>:35
val data = inputLines.map(_.split(",")).map(x => (x(0).toDouble, Vectors.dense(x(1).toDouble)))
data: org.apache.spark.rdd.RDD[(Double, org.apache.spark.ml.linalg.Vector)] = MapPartitionsRDD[7] at map at <console>:37
data.take(10)
res9: Array[(Double, org.apache.spark.ml.linalg.Vector)] = Array((-1.74,[1.66]), (1.24,[-1.18]), (0.29,[-0.4]), (-0.13,[0.09]), (-0.39,[0.38]), (-1.79,[1.73]), (0.71,[-0.77]), (1.39,[-1.48]), (1.15,[-1.43]), (0.13,[-0.07]))
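Each line of regression.txt is evidently a label,feature pair (e.g. -1.74,1.66, matching the first row above). Under that assumption, an alternative sketch reads the file straight into a DataFrame and builds the vector column with VectorAssembler; the intermediate column name "x" is made up for illustration:

// Alternative: read the CSV directly instead of mapping over an RDD.
import org.apache.spark.ml.feature.VectorAssembler

val fileSchema = StructType(Seq(
  StructField("label", DoubleType),
  StructField("x", DoubleType)))    // "x" is a hypothetical column name

val raw = spark.read.schema(fileSchema).csv("data/regression.txt")

val assembled = new VectorAssembler()
  .setInputCols(Array("x"))
  .setOutputCol("features")
  .transform(raw)
  .select("label", "features")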
import spark.implicits._
val colNames = Seq("label", "features")
val df = data.toDF(colNames: _*)
import spark.implicits._
colNames: Seq[String] = List(label, features)
df: org.apache.spark.sql.DataFrame = [label: double, features: vector]
df.show()
+-----+--------+
|label|features|
+-----+--------+
|-1.74| [1.66]|
| 1.24| [-1.18]|
| 0.29| [-0.4]|
|-0.13| [0.09]|
|-0.39| [0.38]|
|-1.79| [1.73]|
| 0.71| [-0.77]|
| 1.39| [-1.48]|
| 1.15| [-1.43]|
| 0.13| [-0.07]|
| 0.05| [-0.07]|
| 1.9| [-1.8]|
| 1.48| [-1.42]|
| 0.32| [-0.3]|
|-1.11| [1.0]|
| 0.51| [-0.62]|
|-1.58| [1.45]|
|-0.46| [0.44]|
|-0.49| [0.37]|
| 0.31| [-0.3]|
+-----+--------+
only showing top 20 rows
val trainTest = df.randomSplit(Array(0.5, 0.5))
val trainingDF = trainTest(0)
val testDF = trainTest(1)
trainTest: Array[org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]] = Array([label: double, features: vector], [label: double, features: vector])
trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
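Note that randomSplit draws a new split on every run; passing a seed makes it reproducible. A minimal variant (the seed value and the new val names are arbitrary):

// Reproducible 50/50 split.
val Array(trainingDF2, testDF2) = df.randomSplit(Array(0.5, 0.5), seed = 42L)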
val lir = new LinearRegression()
.setRegParam(0.3) // regularization
.setElasticNetParam(0.8) // elastic net mixing
.setMaxIter(100) // max iterations
.setTol(1E-6) // convergence tolerance
lir: org.apache.spark.ml.regression.LinearRegression = linReg_a635d25c73c9
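For reference, Spark's LinearRegression combines these two settings as an elastic-net penalty: regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2). An elasticNetParam of 0 is pure ridge (L2) and 1 is pure lasso (L1), so the 0.8 above leans heavily toward L1.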
// Train the model using our training data
val model = lir.fit(trainingDF)
2019-04-05 21:33:55 WARN BLAS:61 - Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
2019-04-05 21:33:55 WARN BLAS:61 - Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
model: org.apache.spark.ml.regression.LinearRegressionModel = linReg_a635d25c73c9
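Before transforming the test set, it can be worth inspecting what was learned; these are standard LinearRegressionModel accessors (the printed values will vary with the random split):

// Fitted line and training diagnostics.
println(s"Coefficients: ${model.coefficients}  Intercept: ${model.intercept}")
println(s"Training RMSE: ${model.summary.rootMeanSquaredError}")
println(s"Training r2:   ${model.summary.r2}")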
val fullPredictions = model.transform(testDF).cache()
fullPredictions: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector ... 1 more field]
fullPredictions.show()
+-----+--------+-------------------+
|label|features| prediction|
+-----+--------+-------------------+
|-3.74| [3.75]| -2.697115171264099|
|-2.36| [2.63]| -1.892269537161224|
|-2.09| [1.97]|-1.4179855027791723|
|-2.07| [2.04]| -1.468288354910602|
| -2.0| [2.02]|-1.4539161114444794|
|-1.94| [1.98]|-1.4251716245122337|
|-1.91| [1.83]| -1.317379798516313|
|-1.91| [1.86]|-1.3389381637154971|
|-1.87| [1.98]|-1.4251716245122337|
| -1.8| [1.84]|-1.3245659202493745|
|-1.75| [1.69]|-1.2167740942534535|
|-1.74| [1.66]|-1.1952157290542693|
|-1.66| [1.64]|-1.1808434855881464|
|-1.65| [1.63]|-1.1736573638550851|
|-1.64| [1.84]|-1.3245659202493745|
|-1.61| [1.72]|-1.2383324594526377|
|-1.53| [1.68]|-1.2095879725203922|
|-1.47| [1.46]|-1.0514932943930415|
|-1.42| [1.59]|-1.1449128769228398|
| -1.4| [1.32]|-0.9508875901301823|
+-----+--------+-------------------+
only showing top 20 rows
val predictionAndLabel = fullPredictions.select("prediction", "label")
.rdd.map(x => (x.getDouble(0), x.getDouble(1)))
predictionAndLabel: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[66] at map at <console>:70
// (prediction, true label)
predictionAndLabel.take(10).foreach(println)
(-2.697115171264099,-3.74)
(-1.892269537161224,-2.36)
(-1.4179855027791723,-2.09)
(-1.468288354910602,-2.07)
(-1.4539161114444794,-2.0)
(-1.4251716245122337,-1.94)
(-1.317379798516313,-1.91)
(-1.3389381637154971,-1.91)
(-1.4251716245122337,-1.87)
(-1.3245659202493745,-1.8)
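Printing pairs only eyeballs the fit; to put a number on it, a RegressionEvaluator can score the cached predictions DataFrame directly. A minimal sketch:

// Test-set RMSE from the label and prediction columns.
import org.apache.spark.ml.evaluation.RegressionEvaluator

val rmse = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
  .evaluate(fullPredictions)
println(s"Test RMSE: $rmse")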
// Stop the session
spark.stop()