Scikit-learn OneHotEncoder fit and transform error: ValueError: X has different shape than during fitting

Problem description:

Below is my code.

I know why the error occurs during transform: the feature list differs between fit and transform. How can I solve this? How can I get 0 for all the remaining features?

After this, I want to use the encoded data with partial_fit of an SGD classifier.

Jupyter QtConsole 4.3.1
Python 3.6.2 |Anaconda custom (64-bit)| (default, Sep 21 2017, 18:29:43) 
Type 'copyright', 'credits' or 'license' for more information
IPython 6.1.0 -- An enhanced Interactive Python. Type '?' for help.

import pandas as pd
from sklearn.preprocessing import OneHotEncoder

input_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'], 
                             color=['Red', 'Orange','Green'],
                             is_sweet = [0,0,1],
                             country=['USA','India','Asia']))
input_df
Out[1]: 
    color country   fruit  is_sweet
0     Red     USA   Apple         0
1  Orange   India  Orange         0
2   Green    Asia    Pine         1



filtered_df = input_df.apply(pd.to_numeric, errors='ignore')
filtered_df.info()
# apply one hot encode
refreshed_df = pd.get_dummies(filtered_df)
refreshed_df
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3 entries, 0 to 2
Data columns (total 4 columns):
color       3 non-null object
country     3 non-null object
fruit       3 non-null object
is_sweet    3 non-null int64
dtypes: int64(1), object(3)
memory usage: 176.0+ bytes


Out[2]: 
   is_sweet  color_Green  color_Orange  color_Red  country_Asia  \
0         0            0             0          1             0   
1         0            0             1          0             0   
2         1            1             0          0             1   

   country_India  country_USA  fruit_Apple  fruit_Orange  fruit_Pine  
0              0            1            1             0           0  
1              1            0            0             1           0  
2              0            0            0             0           1  



enc = OneHotEncoder()
enc.fit(refreshed_df)

Out[3]: 
OneHotEncoder(categorical_features='all', dtype=<class 'numpy.float64'>,
       handle_unknown='error', n_values='auto', sparse=True)



new_df = pd.DataFrame(dict(fruit=['Apple'], 
                             color=['Red'],
                             is_sweet = [0],
                             country=['USA']))
new_df


Out[4]: 
  color country  fruit  is_sweet
0   Red     USA  Apple         0



filtered_df1 = new_df.apply(pd.to_numeric, errors='ignore')
filtered_df1.info()
# apply one hot encode
refreshed_df1 = pd.get_dummies(filtered_df1)
refreshed_df1
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1 entries, 0 to 0
Data columns (total 4 columns):
color       1 non-null object
country     1 non-null object
fruit       1 non-null object
is_sweet    1 non-null int64
dtypes: int64(1), object(3)
memory usage: 112.0+ bytes



Out[5]: 
   is_sweet  color_Red  country_USA  fruit_Apple
0         0          1            1            1

enc.transform(refreshed_df1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-33a6a884ba3f> in <module>()
----> 1 enc.transform(refreshed_df1)

~/anaconda3/lib/python3.6/site-packages/sklearn/preprocessing/data.py in transform(self, X)
   2073         """
   2074         return _transform_selected(X, self._transform,
-> 2075                                    self.categorical_features, copy=True)
   2076 
   2077 

~/anaconda3/lib/python3.6/site-packages/sklearn/preprocessing/data.py in _transform_selected(X, transform, selected, copy)
   1810 
   1811     if isinstance(selected, six.string_types) and selected == "all":
-> 1812         return transform(X)
   1813 
   1814     if len(selected) == 0:

~/anaconda3/lib/python3.6/site-packages/sklearn/preprocessing/data.py in _transform(self, X)
   2030             raise ValueError("X has different shape than during fitting."
   2031                              " Expected %d, got %d."
-> 2032                              % (indices.shape[0] - 1, n_features))
   2033 
   2034         # We use only those categorical features of X that are known using fit.

ValueError: X has different shape than during fitting. Expected 10, got 4.

Instead of using pd.get_dummies(), you need LabelEncoder + OneHotEncoder, which store the categories seen during fitting and can then apply the same mapping to new data.

Changing your code as shown below will give you the required results.

import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
input_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'], 
                             color=['Red', 'Orange','Green'],
                             is_sweet = [0,0,1],
                             country=['USA','India','Asia']))

filtered_df = input_df.apply(pd.to_numeric, errors='ignore')

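# This is what you need: fit one LabelEncoder per column and keep it,
# so the same label-to-integer mapping can be reused on new data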
le_dict = {}
for col in filtered_df.columns:
    le_dict[col] = LabelEncoder().fit(filtered_df[col])
    filtered_df[col] = le_dict[col].transform(filtered_df[col])

enc = OneHotEncoder()
enc.fit(filtered_df)
refreshed_df = enc.transform(filtered_df).toarray()

new_df = pd.DataFrame(dict(fruit=['Apple'], 
                             color=['Red'],
                             is_sweet = [0],
                             country=['USA']))
# Reuse the fitted encoders; do not refit them on the new data
for col in new_df.columns:
    new_df[col] = le_dict[col].transform(new_df[col])

new_refreshed_df = enc.transform(new_df).toarray()

print(filtered_df)
      color  country  fruit  is_sweet
0      2        2      0         0
1      1        1      1         0
2      0        0      2         1

print(refreshed_df)
[[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  1.  0.  0.  1.  0.  1.  0.]
 [ 1.  0.  0.  1.  0.  0.  0.  0.  1.  0.  1.]]

print(new_df)
      color  country  fruit  is_sweet
0      2        2      0         0

print(new_refreshed_df)
[[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]]
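
Note that LabelEncoder.transform raises a ValueError for a label it did not see during fit, so the code above only works when the new data contains known categories. As a possible alternative, here is a minimal sketch assuming scikit-learn 0.20 or later, where OneHotEncoder accepts string columns directly. With handle_unknown='ignore', a category that was not seen during fit is encoded as all zeros for that feature. The names train_df and cat_cols, the 'Kiwi' row, and the use of is_sweet as the target for partial_fit are assumptions made for illustration; they are not part of the original question.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import OneHotEncoder

train_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'],
                             color=['Red', 'Orange', 'Green'],
                             is_sweet=[0, 0, 1],
                             country=['USA', 'India', 'Asia']))
cat_cols = ['color', 'country', 'fruit']  # hypothetical choice of columns to encode

# handle_unknown='ignore': categories not seen during fit become all zeros
enc = OneHotEncoder(handle_unknown='ignore')
X_train = enc.fit_transform(train_df[cat_cols]).toarray()

# 'Kiwi' was never seen during fit, so its fruit block is all zeros
new_df = pd.DataFrame(dict(fruit=['Kiwi'], color=['Red'], country=['USA']))
X_new = enc.transform(new_df[cat_cols]).toarray()

print(X_train.shape)  # (3, 9)
print(X_new)

# Assumption for illustration only: treat is_sweet as the target.
# partial_fit needs the full list of classes on its first call.
clf = SGDClassifier(random_state=0)
clf.partial_fit(X_train, train_df['is_sweet'], classes=np.array([0, 1]))
```

Because the fitted encoder always produces the same nine columns, each later batch can be passed through enc.transform and then to clf.partial_fit without a shape mismatch.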