How can I add some custom fields (e.g. a user ID) to the prediction results?
List<org.apache.spark.mllib.regression.LabeledPoint> localTesting = ... ;
// I want to add some identifier to each LabeledPoint
DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);
DataFrame predictions = model.transform(localTestDF);
Row[] collect = predictions.select("label", "probability", "prediction").collect();
for (Row r : collect) {
    // ...and I want to get the identifier back here,
    // so that I can save it to the database.
    int userNo = Integer.parseInt(r.get(0).toString());
    double prob = Double.parseDouble(r.get(1).toString());
    int prediction = Integer.parseInt(r.get(2).toString());
    log.debug(userNo + "," + prob + ", " + prediction);
}
But when I use the following class for localTesting instead of LabeledPoint:
class NoLabeledPoint extends LabeledPoint implements Serializable {
    private static final long serialVersionUID = -2488661810406135403L;
    int userNo;

    public NoLabeledPoint(double label, Vector features) {
        super(label, features);
    }

    public int getUserNo() {
        return userNo;
    }

    public void setUserNo(int userNo) {
        this.userNo = userNo;
    }
}
List<NoLabeledPoint> localTesting = ... ; // each user's number is stored in the userNo field
// I want to add some identifier to each LabeledPoint
DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);
DataFrame predictions = model.transform(localTestDF);
Row[] collect = predictions.select("userNo", "probability", "prediction").collect();
for (Row r : collect) {
    // ...and I want to get the identifier back here,
    // so that I can save it to the database.
    int userNo = Integer.parseInt(r.get(0).toString());
    double prob = Double.parseDouble(r.get(1).toString());
    int prediction = Integer.parseInt(r.get(2).toString());
    log.debug(userNo + "," + prob + ", " + prediction);
}
this exception is thrown:
org.apache.spark.sql.AnalysisException: cannot resolve 'userNo' given input columns rawPrediction, probability, features, label, prediction;
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:63)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$checkAnalysis$1$$anonfun$apply$2.applyOrElse(CheckAnalysis.scala:52)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:286)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:51)
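What I mean is that I want to get back not only the prediction data (features, label, probability, ...) but also some custom fields of my own, such as userNo or user_id, from the result of predictions.select("...").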
Update
Solved. Only one line needed to be fixed.
From
DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), LabeledPoint.class);
to
DataFrame localTestDF = jsql.createDataFrame(jsc.parallelize(studyData.localTesting), NoLabeledPoint.class);
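The reason that one argument matters is that the bean-based createDataFrame builds its columns from the class you pass in, not from the runtime type of the objects in the RDD. Here is a minimal sketch of that idea using only java.beans.Introspector from the JDK. It assumes Spark derives columns from readable JavaBean properties in roughly this way; BasePoint, PointWithUser and columnsOf are hypothetical stand-ins I made up for illustration, not Spark classes.

```java
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.util.Set;
import java.util.TreeSet;

class BeanColumnsSketch {
    // Hypothetical stand-in for LabeledPoint (not Spark's class).
    public static class BasePoint {
        private final double label;
        public BasePoint(double label) { this.label = label; }
        public double getLabel() { return label; }
    }

    // Hypothetical stand-in for NoLabeledPoint: adds a userNo property.
    public static class PointWithUser extends BasePoint {
        private int userNo;
        public PointWithUser(double label) { super(label); }
        public int getUserNo() { return userNo; }
        public void setUserNo(int userNo) { this.userNo = userNo; }
    }

    // Names of the readable bean properties of beanClass, skipping the
    // "class" property that every object inherits from Object.getClass().
    public static Set<String> columnsOf(Class<?> beanClass) {
        Set<String> names = new TreeSet<>();
        try {
            for (PropertyDescriptor pd
                    : Introspector.getBeanInfo(beanClass).getPropertyDescriptors()) {
                if (pd.getReadMethod() != null && !"class".equals(pd.getName())) {
                    names.add(pd.getName());
                }
            }
        } catch (IntrospectionException e) {
            throw new IllegalStateException("Cannot introspect " + beanClass, e);
        }
        return names;
    }

    public static void main(String[] args) {
        System.out.println(columnsOf(BasePoint.class));     // [label]
        System.out.println(columnsOf(PointWithUser.class)); // [label, userNo]
    }
}
```

Passing the parent class hides userNo even when every object is an instance of the subclass, which matches the "cannot resolve 'userNo'" error above.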
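Best answer
Since you're not using the low-level MLlib API, you don't need LabeledPoint at all. Once the DataFrame has been created you are simply working with Rows that hold particular values, and all that matters* is that the types and column names match the parameters of your pipeline.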
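In Scala you can use any case class:
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
case class LabeledPointWithMeta(userNo: String, label: Double, features: Vector)
val rdd: RDD[LabeledPointWithMeta] = ???
// toDF requires the SQLContext implicits to be in scope: import sqlContext.implicits._
val df = rdd.toDF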
To be able to use it, you should add the @BeanInfo annotation:
import scala.beans.BeanInfo
@BeanInfo
case class LabeledPointWithMeta(...)
Based on the Spark SQL and DataFrame Guide, it looks like in plain Java you can do something like this**:
import java.io.Serializable;
import org.apache.spark.mllib.linalg.Vector;

// Bean property names become the DataFrame column names, so the vector
// property is called "features" here, as in the Scala case class above.
public static class LabeledPointWithMeta implements Serializable {
    private int userNo;
    private double label;
    private Vector features;

    public int getUserNo() {
        return userNo;
    }

    public void setUserNo(int userNo) {
        this.userNo = userNo;
    }

    public double getLabel() {
        return label;
    }

    public void setLabel(double label) {
        this.label = label;
    }

    public Vector getFeatures() {
        return features;
    }

    public void setFeatures(Vector features) {
        this.features = features;
    }
}
After that:
JavaRDD<LabeledPointWithMeta> myPoints = ...;
DataFrame df = sqlContext.createDataFrame(myPoints, LabeledPointWithMeta.class);
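Since the column names have to match what the pipeline expects, it is worth checking which property names a bean actually exposes. The sketch below uses only java.beans.Introspector from the JDK and assumes the column name is the decapitalized getter name, which is how JavaBean properties are named. VectorNamedPoint, FeaturesNamedPoint and properties are hypothetical names for illustration, and double[] stands in for Spark's Vector so the example runs without Spark.

```java
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.util.Map;
import java.util.TreeMap;

class BeanPropertyNamesSketch {
    // Hypothetical bean: the getter is getVector and the setter name has a typo.
    public static class VectorNamedPoint {
        private double[] vector;
        public double[] getVector() { return vector; }
        public void seVector(double[] vector) { this.vector = vector; } // not recognized as a setter
    }

    // Hypothetical bean: getter and setter follow the getX/setX convention.
    public static class FeaturesNamedPoint {
        private double[] features;
        public double[] getFeatures() { return features; }
        public void setFeatures(double[] features) { this.features = features; }
    }

    // Maps each readable property name to whether a matching setter was found.
    public static Map<String, Boolean> properties(Class<?> beanClass) {
        Map<String, Boolean> result = new TreeMap<>();
        try {
            for (PropertyDescriptor pd
                    : Introspector.getBeanInfo(beanClass).getPropertyDescriptors()) {
                if (pd.getReadMethod() != null && !"class".equals(pd.getName())) {
                    result.put(pd.getName(), pd.getWriteMethod() != null);
                }
            }
        } catch (IntrospectionException e) {
            throw new IllegalStateException("Cannot introspect " + beanClass, e);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(properties(VectorNamedPoint.class));   // {vector=false}
        System.out.println(properties(FeaturesNamedPoint.class)); // {features=true}
    }
}
```

A getter called getVector gives a property named vector, so an estimator left at its default features column would not find it. Whether a missing setter matters depends on the Spark version and the operation, but a misspelled setter is cheap to catch this way.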
I think a simple change in your code should work as well:
DataFrame localTestDF = jsql.createDataFrame(
jsc.parallelize(studyData.localTesting),
NoLabeledPoint.class
);
This won't help you if you want to use MLlib, but that part can easily be handled with simple RDD transformations (such as zip).
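To illustrate what zip does, here is a minimal plain-Java sketch that pairs identifiers with predictions by position. It uses ordinary lists rather than RDDs so it runs without Spark; ZipSketch, zip and the sample values are made up for illustration. With Spark you would call zip on the RDDs themselves, which as far as I know requires both sides to have the same number of partitions and the same number of elements in each partition.

```java
import java.util.AbstractMap.SimpleEntry;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

class ZipSketch {
    // Pairs the i-th element of left with the i-th element of right.
    public static <A, B> List<Map.Entry<A, B>> zip(List<A> left, List<B> right) {
        if (left.size() != right.size()) {
            throw new IllegalArgumentException("Both lists must have the same length");
        }
        List<Map.Entry<A, B>> pairs = new ArrayList<>(left.size());
        for (int i = 0; i < left.size(); i++) {
            pairs.add(new SimpleEntry<>(left.get(i), right.get(i)));
        }
        return pairs;
    }

    public static void main(String[] args) {
        // Hypothetical data: identifiers kept aside, predictions computed separately.
        List<Integer> userNos = Arrays.asList(101, 102, 103);
        List<Double> predictions = Arrays.asList(1.0, 0.0, 1.0);
        for (Map.Entry<Integer, Double> pair : zip(userNos, predictions)) {
            System.out.println(pair.getKey() + "," + pair.getValue());
        }
    }
}
```

This only works if both sides keep the same order, so the identifiers and the predictions should come from the same source collection without any reordering or filtering in between.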
* And some metadata, but you won't get that from LabeledPoint anyway.
** I haven't tested the code above, so it may contain some errors.