sklearn中数据缩放用到的fit_transform()、transform()、fit()方法的区别与联系

sklearn中数据缩放用到的fit_transform()、transform()、fit()方法的区别与联系

看了一堆搜索排名靠前的中文博客,感觉没有一个解释能让人醍醐灌顶的,故搜索英文网页并记之。

谢绝转载。

首先对于数据标准化一般是这么做的:

sklearn中数据缩放用到的fit_transform()、transform()、fit()方法的区别与联系

 其中σ是标准差。目的是使数据服从均值为零,标准差为1的标准正态分布,此即标准化(Standardization)。

标准化都是给训练集数据做的,但在以下情况中也必须做数据标准化,比如,交叉验证时的测试集,或者是预测前获得了一组新的样本。而在对新的数据或测试集进行标准化时,我们所用的是训练集标准化中的均值μ和标准差σ。

因此,StandardScaler 中的fit()所做的就是计算数据的均值μ和标准差σ,并将他们储存为一个内部对象的状态,无返回值。然后,对测试集调用transform()方法,此方法将使用刚刚fit()计算得到的均值μ和标准差σ来对测试集数据进行标准化。

而fit_transform()就是将以上两步二合一,因为其内部就是先后调用fit()和transform()函数的。

所以我们经常能看到类似这样的代码:

1 # Feature Scaling
2 from sklearn.preprocessing import StandardScaler
3 sc = StandardScaler()
4 X_train = sc.fit_transform(X_train)
5 X_test = sc.transform(X_test)

注意这里fit_transform()是用在训练集上的,也就是说,fit_transform()先计算了训练集数据的均值μ和标准差σ,并以此对训练集进行标准化。

参考:

https://datascience.stackexchange.com/questions/12321/whats-the-difference-between-fit-and-fit-transform-in-scikit-learn-models

https://www.kaggle.com/questions-and-answers/58368