package Spark_MLlib
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
/*
 * Hyperparameter tuning and model selection with cross-validation.
 */
/**
 * One parsed input record: a feature vector paired with its class label.
 *
 * @param features numeric features of the record
 * @param label    class label, kept as the raw string
 */
case class schema_source(
  features: Vector,
  label: String
)
/**
 * Cross-validated hyperparameter tuning for a logistic-regression pipeline.
 *
 * `main` reads a comma-separated text file (four numeric columns followed by a string
 * label), builds a pipeline of label indexing, feature indexing, logistic regression and
 * index-to-label conversion, searches a grid of `regParam` / `elasticNetParam` values with
 * 3-fold cross-validation on a 70% training split, and prints the accuracy on the
 * remaining 30% together with details of the best logistic-regression model.
 */
object 交叉验证_调参_逻辑回归 {
  // Local session using two worker threads; created when the object is initialised.
  val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate()
  import spark.implicits._

  /**
   * Program entry point: trains, tunes and evaluates the model, printing results to stdout.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    try {
      // Each non-blank line is "f0,f1,f2,f3,label". Blank lines (e.g. a trailing newline
      // at the end of the file) are skipped so they do not fail in toDouble.
      val data = spark.sparkContext
        .textFile("file:///home/soyo/桌面/spark编程测试数据/soyo.txt")
        .filter(_.trim.nonEmpty)
        .map(_.split(","))
        .map { x =>
          schema_source(
            Vectors.dense(x(0).toDouble, x(1).toDouble, x(2).toDouble, x(3).toDouble),
            x(4)
          )
        }
        .toDF()
      data.show()

      // Both indexers are fitted on the full data set (before the split), so every label
      // value present in the file receives an index.
      val labelIndexer = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .fit(data)
      val featuresIndexer = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .fit(data)

      // No seed is passed, so the 70/30 split differs from run to run.
      val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3))

      val lr = new LogisticRegression()
        .setLabelCol("indexedLabel")
        .setFeaturesCol("indexedFeatures")
        .setMaxIter(50)

      // Maps the numeric prediction back to the original label string.
      val labelConverter = new IndexToString()
        .setInputCol("prediction")
        .setOutputCol("predictedLabel")
        .setLabels(labelIndexer.labels)
      labelIndexer.labels.foreach(println)

      // Machine-learning workflow.
      val lrPipeline = new Pipeline()
        .setStages(Array(labelIndexer, featuresIndexer, lr, labelConverter))

      // Evaluator used by cross-validation to rank candidates. No metric name is set, so
      // the evaluator's default metric applies (F1 according to the Spark documentation).
      val evaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")

      // Parameter grid: 3 regParam values x 2 elasticNetParam values = 6 candidates.
      val paramGrid = new ParamGridBuilder()
        .addGrid(lr.regParam, Array(0.01, 0.3, 0.8))
        .addGrid(lr.elasticNetParam, Array(0.3, 0.9))
        .build()

      // Cross-validation over the whole pipeline: estimator, evaluator, grid and fold count.
      val cv = new CrossValidator()
        .setEstimator(lrPipeline)
        .setEvaluator(evaluator)
        .setEstimatorParamMaps(paramGrid)
        .setNumFolds(3)

      // Train on the training split, then score the held-out split.
      val cvModel = cv.fit(trainData)
      val lrPrediction = cvModel.transform(testData)
      lrPrediction.show()

      // The value below is reported as accuracy, so the metric must be requested
      // explicitly; leaving it unset would report the evaluator's default metric instead.
      val evaluator2 = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")
      val lrAccuracy = evaluator2.evaluate(lrPrediction)
      println(s"准确率为: $lrAccuracy")
      println(s"错误率为: ${1 - lrAccuracy}")

      // Best model found by cross-validation. Stage index 2 is the logistic-regression
      // stage, matching the order passed to setStages above.
      val bestModel = cvModel.bestModel.asInstanceOf[PipelineModel]
      val lrModel = bestModel.stages(2).asInstanceOf[LogisticRegressionModel]
      println(s"二项逻辑回归模型系数矩阵: ${lrModel.coefficientMatrix}")
      println(s"二项逻辑回归模型的截距向量: ${lrModel.interceptVector}")
      println(s"类的数量(标签可以使用的值): ${lrModel.numClasses}")
      println(s"模型所接受的特征的数量: ${lrModel.numFeatures}")
      println(s"所有参数的设置为: ${lrModel.explainParams()}")
      println(s"最优的regParam的值为: ${lrModel.explainParam(lrModel.regParam)}")
      println(s"最优的elasticNetParam的值为: ${lrModel.explainParam(lrModel.elasticNetParam)}")
    } finally {
      // Release the session's resources whether or not the job succeeded.
      spark.stop()
    }
  }
}
/*
Sample console output captured from a run of the program above:

+-----------------+-----+
| features|label|
+-----------------+-----+
|[5.1,3.5,1.4,0.2]|soyo1|
|[4.9,3.0,1.4,0.2]|soyo1|
|[4.7,3.2,1.3,0.2]|soyo1|
|[4.6,3.1,1.5,0.2]|soyo1|
|[5.0,3.6,1.4,0.2]|soyo1|
|[5.4,3.9,1.7,0.4]|soyo1|
|[4.6,3.4,1.4,0.3]|soyo1|
|[5.0,3.4,1.5,0.2]|soyo1|
|[4.4,2.9,1.4,0.2]|soyo1|
|[4.9,3.1,1.5,0.1]|soyo1|
|[5.4,3.7,1.5,0.2]|soyo1|
|[4.8,3.4,1.6,0.2]|soyo1|
|[4.8,3.0,1.4,0.1]|soyo1|
|[4.3,3.0,1.1,0.1]|soyo1|
|[5.8,4.0,1.2,0.2]|soyo1|
|[5.7,4.4,1.5,0.4]|soyo1|
|[5.4,3.9,1.3,0.4]|soyo1|
|[5.1,3.5,1.4,0.3]|soyo1|
|[5.7,3.8,1.7,0.3]|soyo1|
|[5.1,3.8,1.5,0.3]|soyo1|
+-----------------+-----+
only showing top 20 rows
soyo2
soyo1
soyo3
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
| features|label|indexedLabel| indexedFeatures| rawPrediction| probability|prediction|predictedLabel|
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
|[4.3,3.0,1.1,0.1]|soyo1| 1.0|[4.3,3.0,1.1,0.1]|[-0.2949197997435...|[0.00821657808181...| 1.0| soyo1|
|[4.4,2.9,1.4,0.2]|soyo1| 1.0|[4.4,2.9,1.4,0.2]|[-0.1436502505351...|[0.02310764702310...| 1.0| soyo1|
|[4.6,3.1,1.5,0.2]|soyo1| 1.0|[4.6,3.1,1.5,0.2]|[-0.1980725396328...|[0.01584026165726...| 1.0| soyo1|
|[4.8,3.0,1.4,0.1]|soyo1| 1.0|[4.8,3.0,1.4,0.1]|[-0.0360182992158...|[0.01909506488946...| 1.0| soyo1|
|[4.8,3.1,1.6,0.2]|soyo1| 1.0|[4.8,3.1,1.6,0.2]|[-0.0963956817735...|[0.02165865158723...| 1.0| soyo1|
|[4.8,3.4,1.6,0.2]|soyo1| 1.0|[4.8,3.4,1.6,0.2]|[-0.3305444022091...|[0.00764403083532...| 1.0| soyo1|
|[4.9,2.4,3.3,1.0]|soyo2| 0.0|[4.9,2.4,3.3,1.0]|[0.64687664475266...|[0.83588965920895...| 0.0| soyo2|
|[4.9,3.0,1.4,0.2]|soyo1| 1.0|[4.9,3.0,1.4,0.2]|[0.00894554123863...|[0.02696343238302...| 1.0| soyo1|
|[5.0,3.5,1.6,0.6]|soyo1| 1.0|[5.0,3.5,1.6,0.6]|[-0.3209967599706...|[0.01781564148264...| 1.0| soyo1|
|[5.0,3.6,1.4,0.2]|soyo1| 1.0|[5.0,3.6,1.4,0.2]|[-0.4132228265822...|[0.00370148550004...| 1.0| soyo1|
|[5.1,3.7,1.5,0.4]|soyo1| 1.0|[5.1,3.7,1.5,0.4]|[-0.4380550804437...|[0.00533390253840...| 1.0| soyo1|
|[5.1,3.8,1.9,0.4]|soyo1| 1.0|[5.1,3.8,1.9,0.4]|[-0.4784298068885...|[0.00593236888116...| 1.0| soyo1|
|[5.2,2.7,3.9,1.4]|soyo2| 0.0|[5.2,2.7,3.9,1.4]|[0.60296648363520...|[0.65499655703255...| 0.0| soyo2|
|[5.2,3.5,1.5,0.2]|soyo1| 1.0|[5.2,3.5,1.5,0.2]|[-0.2334963952443...|[0.00721300202565...| 1.0| soyo1|
|[5.3,3.7,1.5,0.2]|soyo1| 1.0|[5.3,3.7,1.5,0.2]|[-0.3434664691509...|[0.00396451436269...| 1.0| soyo1|
|[5.4,3.4,1.5,0.4]|soyo1| 1.0|[5.4,3.4,1.5,0.4]|[-0.0655191408567...|[0.02050202848213...| 1.0| soyo1|
|[5.4,3.4,1.7,0.2]|soyo1| 1.0|[5.4,3.4,1.7,0.2]|[-0.0443512521479...|[0.01568504280438...| 1.0| soyo1|
|[5.4,3.9,1.3,0.4]|soyo1| 1.0|[5.4,3.9,1.3,0.4]|[-0.4746044317663...|[0.00285607924154...| 1.0| soyo1|
|[5.4,3.9,1.7,0.4]|soyo1| 1.0|[5.4,3.9,1.7,0.4]|[-0.4369295847326...|[0.00451151133277...| 1.0| soyo1|
|[5.5,2.3,4.0,1.3]|soyo2| 0.0|[5.5,2.3,4.0,1.3]|[1.06413594105520...|[0.51327715648015...| 0.0| soyo2|
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 20 rows
准确率为: 0.9418343292582645
错误率为: 0.05816567074173551
二项逻辑回归模型系数矩阵: 0.4612907305046201 -0.7804957347855317 0.09418711758439907 -0.011652325959556013
-0.559055378870932 2.7385209747134933 -1.052922922424876 -2.5223769474140303
-0.07629895224519458 -3.6867236615320547 1.0014498171011217 4.581938360185545
二项逻辑回归模型的截距向量: [-0.039423333303658874,0.0972586768296292,-0.05783534352597033]
类的数量(标签可以使用的值): 3
模型所接受的特征的数量: 4
所有参数的设置为: aggregationDepth: suggested depth for treeAggregate (>= 2) (default: 2)
elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.9)
family: The name of family which is a description of the label distribution to be used in the model. Supported options: auto, binomial, multinomial. (default: auto)
featuresCol: features column name (default: features, current: indexedFeatures)
fitIntercept: whether to fit an intercept term (default: true)
labelCol: label column name (default: label, current: indexedLabel)
lowerBoundsOnCoefficients: The lower bounds on coefficients if fitting under bound constrained optimization. (undefined)
lowerBoundsOnIntercepts: The lower bounds on intercepts if fitting under bound constrained optimization. (undefined)
maxIter: maximum number of iterations (>= 0) (default: 100, current: 50)
predictionCol: prediction column name (default: prediction)
probabilityCol: Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities (default: probability)
rawPredictionCol: raw prediction (a.k.a. confidence) column name (default: rawPrediction)
regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
standardization: whether to standardize the training features before fitting the model (default: true)
threshold: threshold in binary classification prediction, in range [0, 1] (default: 0.5)
thresholds: Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values > 0 excepting that at most one value may be 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class's threshold (undefined)
tol: the convergence tolerance for iterative algorithms (>= 0) (default: 1.0E-6)
upperBoundsOnCoefficients: The upper bounds on coefficients if fitting under bound constrained optimization. (undefined)
upperBoundsOnIntercepts: The upper bounds on intercepts if fitting under bound constrained optimization. (undefined)
weightCol: weight column name. If this is not set or empty, we treat all instance weights as 1.0 (undefined)
最优的regParam的值为: regParam: regularization parameter (>= 0) (default: 0.0, current: 0.01)
最优的elasticNetParam的值为: elasticNetParam: the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty (default: 0.0, current: 0.9)
*/