One-vs-Rest
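Algorithm overview:
OneVsRest extends a given binary classification algorithm to multiclass problems; the approach is also known as "One-vs-All". OneVsRest is an Estimator. It takes a base Classifier and creates one binary classification problem for each of the k classes. The binary classifier for class i predicts whether a sample belongs to class i or not, separating class i from all other classes. To make a prediction, each of the k binary classifiers is evaluated in turn, and the index of the classifier with the highest confidence is output as the predicted label.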
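The decision rule described above can be illustrated with a minimal, dependency-free sketch. This is not Spark's implementation: the names `one_vs_rest_predict` and `make_linear_scorer`, and the hand-written weights, are hypothetical and exist only to show how the most confident of k binary scorers determines the label.

```python
from typing import Callable, Dict, Hashable, Sequence

# A binary scorer maps a feature vector to a confidence that the
# sample belongs to "its" class (higher means more confident).
Scorer = Callable[[Sequence[float]], float]


def one_vs_rest_predict(scorers: Dict[Hashable, Scorer],
                        features: Sequence[float]) -> Hashable:
    """Return the label whose binary scorer reports the highest confidence."""
    if not scorers:
        raise ValueError("at least one binary classifier is required")
    return max(scorers, key=lambda label: scorers[label](features))


def make_linear_scorer(weights: Sequence[float], bias: float) -> Scorer:
    """Hypothetical stand-in for a trained binary classifier."""
    def score(x: Sequence[float]) -> float:
        return sum(w * v for w, v in zip(weights, x)) + bias
    return score


# One scorer per class; the weights are made up for illustration.
scorers = {
    0: make_linear_scorer([1.0, 0.0], 0.0),
    1: make_linear_scorer([0.0, 1.0], 0.0),
    2: make_linear_scorer([-1.0, -1.0], 0.5),
}

predicted = one_vs_rest_predict(scorers, [2.0, 0.5])
```

In practice each scorer would be a binary model trained on "class i versus everything else"; the sketch only covers the final step of picking the most confident one.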
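Parameters:
featuresCol:
Type: string.
Meaning: name of the features column.
labelCol:
Type: string.
Meaning: name of the label column.
predictionCol:
Type: string.
Meaning: name of the prediction column.
classifier:
Type: Classifier.
Meaning: the base binary classifier.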
Examples:
Scala:
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// load data file.
val inputData = spark.read.format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt")
// generate the train/test split.
val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))
// instantiate the base classifier.
val classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1E-6)
  .setFitIntercept(true)
// instantiate the One Vs Rest Classifier.
val ovr = new OneVsRest().setClassifier(classifier)
// train the multiclass model.
val ovrModel = ovr.fit(train)
// score the model on test data.
val predictions = ovrModel.transform(test)
// obtain evaluator.
val evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy")
// compute the classification error on test data.
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error : ${1 - accuracy}")
Java:
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// load data file.
Dataset<Row> inputData = spark.read().format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt");
// generate the train/test split.
Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = tmp[0];
Dataset<Row> test = tmp[1];
// configure the base classifier.
LogisticRegression classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1E-6)
  .setFitIntercept(true);
// instantiate the One Vs Rest Classifier.
OneVsRest ovr = new OneVsRest().setClassifier(classifier);
// train the multiclass model.
OneVsRestModel ovrModel = ovr.fit(train);
// score the model on test data.
Dataset<Row> predictions = ovrModel.transform(test)
  .select("prediction", "label");
// obtain evaluator.
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy");
// compute the classification error on test data.
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test Error : " + (1 - accuracy));
Python:
from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# load data file.
inputData = spark.read.format("libsvm") \
    .load("data/mllib/sample_multiclass_classification_data.txt")
# generate the train/test split.
(train, test) = inputData.randomSplit([0.8, 0.2])
# instantiate the base classifier.
lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)
# instantiate the One Vs Rest Classifier.
ovr = OneVsRest(classifier=lr)
# train the multiclass model.
ovrModel = ovr.fit(train)
# score the model on test data.
predictions = ovrModel.transform(test)
# obtain evaluator.
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")
# compute the classification error on test data.
accuracy = evaluator.evaluate(predictions)
print("Test Error : " + str(1 - accuracy))