我正在尝试定义一个符合Sklearn估算器的类,例如:
class MyEstimator():
def __init__(self,verbose=False):
self.verbose = verbose
def get_params(self, deep=False):
return {
'verbose': self.verbose,
}
def set_params(self, **parameters):
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self
# Also def fit() and other stuff ...
题
可以在不明确列出所有参数名称的情况下定义set_params().有没有办法以类似的方式定义get_params()?
我需要Sklearn的是GridsearchCV,根据我的尝试,似乎get_params确定在交叉验证期间可以注入哪些参数.
最佳答案 只需从
BaseEstimator继承您的类,它将为您实现get_params()和set_params().
演示:
In [21]: from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, ClusterMixin
In [22]: from sklearn.base import BaseEstimator
...:
...: class MyEstimator(BaseEstimator):
...: def __init__(self,verbose=False):
...: self.verbose = verbose
In [23]: est = MyEstimator(verbose=True)
In [24]: est.get_params()
Out[24]: {'verbose': True}
In [25]: est.set_params(verbose=False)
Out[25]: MyEstimator(verbose=False)
In [26]: est.get_params()
Out[26]: {'verbose': False}
PS你也可能想从(ClassifierMixin,RegressorMixin,ClusterMixin)中的一个继承你的估算器,这取决于你要实现的估算器类型……