如何在sklearn中获取预测值和误差度量

我有两个单独的
python函数,其中一个使用cross_val_predict返回数据集的预测值,另一个使用cross_validate返回多个错误度量值.下面显示的是用于获取度量值的方法(我已经实现了类似的方法来获取预测).

def metric_val(folds):
.
.
.
scoring = {'r_score': 'r2',
           'abs_error': 'neg_mean_absolute_error',
           'squared_error': 'neg_mean_squared_error'}

scores = cross_validate(best_svr, X, y, scoring=scoring, cv=folds, return_train_score=True)

print("****\nR2 :", "", scores['test_r_score'].mean(),
      "| MAE :", scores['test_abs_error'].mean(),
      )
return prediction

我不想同时使用这两个函数,因为它的计算成本很高.是否有单一的方法或另一种方法来获得预测和指标?

最佳答案 有可能装备一个得分手,以便它返回预测,虽然这有点像黑客.这是怎么做的:

cross_validate()函数可以使用自定义评分函数.评分函数必须返回一个数字,但您可以在函数内部执行任何操作.由于您拥有clf和所有测试数据,只需保存clf.predict()的输出,然后返回一个虚拟值以保持得分者满意.有关详细信息,请参阅Implementing your own scoring object上的sklearn docs.

像这样:

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split, cross_validate, cross_val_predict

# example data
iris = datasets.load_iris()
X, y = iris.data, iris.target 
clf = svm.SVC(probability=True, random_state=0)

定义自定义get_preds()函数,将其作为得分者隐藏:

def get_preds(clf, X, y): # y is required for a scorer but we won't use it
    with open("pred.csv", "ab+") as f: # append each fold to file
        np.savetxt(f, clf.predict(X))
    return 0

scoring = {'preds': get_preds,
           'accuracy': 'accuracy',
           'recall': 'recall_macro'} # add desired scorers here

k = 5
cross_validate(clf, X, y, 
               scoring=scoring, 
               return_train_score=True,
               cv = k)

重新加载get_preds(),重新整形以匹配折叠集,并在折叠中平均:

preds = np.loadtxt("pred.csv").reshape(k, len(X))
my_preds = np.mean(my_preds, axis=0).round()

与cross_val_predict()预测比较:

cv_preds = cross_val_predict(clf, X, y, cv=k)

np.equal(my_preds, cv_preds).sum() # 487 out of 500

我们在makehift get_preds()方法和cross_val_predict()之间看到了几乎完美的一致.小分歧可能是由于我的平均方法与cross_val_predict不同(我只是舍入到最接近的整数类,不是非常复杂),或者它可能与sklearn cross-validation docs中这个稍微神秘的音符有关:

Note that the result of this computation may be slightly different from those obtained using cross_val_score as the elements are grouped in different ways.

点赞