Introduction to the One-vs-Rest Algorithm and Calling It from Spark MLlib (Scala/Java/Python)

One-vs-Rest

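Algorithm overview:

OneVsRest extends a given binary classification algorithm to multiclass problems; it is also known as "One-vs-All". OneVsRest is an Estimator. It takes a base Classifier and creates one binary classification problem for each of the k classes. The binary classifier for class i is trained to predict whether the label is i or not, which separates class i from all other classes. To make a prediction, each of the k binary classifiers is evaluated in turn, and the index of the most confident classifier is output as the predicted label.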
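The reduction described above can be sketched in plain Python, independent of Spark. This is a minimal illustration and not Spark's implementation: the helper names (relabel, fit_centroid_scorer, fit_one_vs_rest, predict_one_vs_rest) are hypothetical, and the centroid-based scorer is only a stand-in for a real base binary classifier such as logistic regression.

```python
from typing import Callable, Dict, Hashable, List, Sequence

Vector = Sequence[float]
# A scorer maps a feature vector to a confidence that it belongs to the positive class.
Scorer = Callable[[Vector], float]


def relabel(labels: Sequence[Hashable], positive: Hashable) -> List[float]:
    """Binary problem for one class: 1.0 for class `positive`, 0.0 for the rest."""
    return [1.0 if y == positive else 0.0 for y in labels]


def _centroid(points: List[Vector]) -> List[float]:
    """Component-wise mean of a non-empty list of vectors."""
    return [sum(column) / len(points) for column in zip(*points)]


def _sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def fit_centroid_scorer(features: Sequence[Vector], binary_labels: Sequence[float]) -> Scorer:
    """Hypothetical stand-in for a base binary classifier.

    The confidence is higher when a point is closer to the positive centroid
    than to the centroid of the rest. Assumes both groups are non-empty.
    """
    pos = [x for x, y in zip(features, binary_labels) if y == 1.0]
    neg = [x for x, y in zip(features, binary_labels) if y == 0.0]
    pos_c, neg_c = _centroid(pos), _centroid(neg)
    return lambda x: _sq_dist(x, neg_c) - _sq_dist(x, pos_c)


def fit_one_vs_rest(features: Sequence[Vector], labels: Sequence[Hashable]) -> Dict[Hashable, Scorer]:
    """Train one binary scorer per class (k scorers for k classes)."""
    return {c: fit_centroid_scorer(features, relabel(labels, c)) for c in set(labels)}


def predict_one_vs_rest(scorers: Dict[Hashable, Scorer], x: Vector) -> Hashable:
    """Evaluate every binary scorer and return the class of the most confident one."""
    return max(scorers, key=lambda c: scorers[c](x))


# Toy data: three well-separated classes in two dimensions.
features = [(0, 0), (0, 1), (5, 5), (5, 6), (10, 0), (10, 1)]
labels = [0, 0, 1, 1, 2, 2]
scorers = fit_one_vs_rest(features, labels)
print(predict_one_vs_rest(scorers, (0.2, 0.5)))
```

Swapping fit_centroid_scorer for any other function that returns a confidence score leaves the relabeling and the argmax step unchanged, which is the point of the reduction.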

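Parameters:

featuresCol:

Type: string.

Description: name of the features column.

labelCol:

Type: string.

Description: name of the label column.

predictionCol:

Type: string.

Description: name of the prediction output column.

classifier:

Type: classifier.

Description: the base binary classifier.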

Examples:

Scala:

import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// load data file.
val inputData = spark.read.format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt")

// generate the train/test split.
val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))

// instantiate the base classifier
val classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1E-6)
  .setFitIntercept(true)

// instantiate the One Vs Rest Classifier.
val ovr = new OneVsRest().setClassifier(classifier)

// train the multiclass model.
val ovrModel = ovr.fit(train)

// score the model on test data.
val predictions = ovrModel.transform(test)

// obtain evaluator.
val evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy")

// compute the classification error on test data.
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error : ${1 - accuracy}")

Java:

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// load data file.
Dataset<Row> inputData = spark.read().format("libsvm")
  .load("data/mllib/sample_multiclass_classification_data.txt");

// generate the train/test split.
Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
Dataset<Row> train = tmp[0];
Dataset<Row> test = tmp[1];

// configure the base classifier.
LogisticRegression classifier = new LogisticRegression()
  .setMaxIter(10)
  .setTol(1E-6)
  .setFitIntercept(true);

// instantiate the One Vs Rest Classifier.
OneVsRest ovr = new OneVsRest().setClassifier(classifier);

// train the multiclass model.
OneVsRestModel ovrModel = ovr.fit(train);

// score the model on test data.
Dataset<Row> predictions = ovrModel.transform(test)
  .select("prediction", "label");

// obtain evaluator.
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  .setMetricName("accuracy");

// compute the classification error on test data.
double accuracy = evaluator.evaluate(predictions);
System.out.println("Test Error : " + (1 - accuracy));

Python:

from pyspark.ml.classification import LogisticRegression, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# load data file.
inputData = spark.read.format("libsvm") \
    .load("data/mllib/sample_multiclass_classification_data.txt")

# generate the train/test split.
(train, test) = inputData.randomSplit([0.8, 0.2])

# instantiate the base classifier.
lr = LogisticRegression(maxIter=10, tol=1E-6, fitIntercept=True)

# instantiate the One Vs Rest Classifier.
ovr = OneVsRest(classifier=lr)

# train the multiclass model.
ovrModel = ovr.fit(train)

# score the model on test data.
predictions = ovrModel.transform(test)

# obtain evaluator.
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

# compute the classification error on test data.
accuracy = evaluator.evaluate(predictions)
print("Test Error : " + str(1 - accuracy))
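
The "Test Error" printed by the three examples is one minus the accuracy metric, that is, the fraction of test rows whose predicted label differs from the true label. A minimal plain-Python sketch of that arithmetic, using made-up prediction and label values purely for illustration (the accuracy function here is a hypothetical helper, not part of Spark):

```python
from typing import Sequence


def accuracy(predictions: Sequence[float], labels: Sequence[float]) -> float:
    """Fraction of positions where the prediction equals the label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("at least one example is required")
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)


# Illustrative values only: four of the five predictions match their labels.
example_predictions = [0.0, 1.0, 2.0, 1.0, 0.0]
example_labels = [0.0, 1.0, 2.0, 2.0, 0.0]

acc = accuracy(example_predictions, example_labels)
test_error = 1 - acc
print("Test Error : " + str(test_error))
```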

Original author: 刘玲源
Original URL: https://zhuanlan.zhihu.com/p/24121959