fit、transform、fit_transformの意味を、正規化の例で解説

最終更新日 2019/05/12

sklearnにおけるデータの変換を行う機能では、
・fit:変換式を計算する
・transform:変換式を使ってデータを変換
・fit_transform:上記2つをまとめて実行

という3つの関数がセットで用意されていることが多いです。

最大値を1に、最小値を0にする正規化(MM正規化)の例を使って3つの関数の役割を解説します(sklearn では MinMaxScaler に対応します)。

fitの役割

fit は変換式を計算します。(データを変換するために必要な統計情報を計算します)

例えば、$(0,1,2)$ というデータに対して、MM正規化における fit 関数を適用すると、変換式 $y=\dfrac{1}{2}x$ を計算してくれます。この際、変換は行われません。

transformの役割

transform は fit の結果を使って、実際にデータを変換します。

例えば、$(0,1,2)$ というデータに対して、fit 関数を適用した後に transform 関数を適用すると、
$(0,1,2)$ が $(0,0.5,1)$ というデータに変換されます。

fit_transform の役割

fit_transform は、fit と transform をまとめて行います。

例えば、$(0,1,2)$ というデータに対して、fit_transform 関数を適用すると、
$(0,1,2)$ が一発で $(0,0.5,1)$ というデータに変換されます。

fit_transform だけあれば、残り2つは必要無いのでは?という気もしますが、そうではありません。理由を以下で説明します。

注意点

テストデータの変換は、訓練データで用いた変換式を使って行う必要があります。

例えば、データ数3の訓練データ $(0,1,2)$ を正規化した $(0,0.5,1)$ を使って回帰モデルを作成したとします。

そして、そのモデルを使ってテストデータの予測をする際には、訓練データに適用したものと同じ変換を行う必要があります。

例えば、テストデータがデータ数2で $(0.5,1.5)$ の場合、これをもとにした変換式 $y=x-0.5$ で正規化して $(0,1)$ としてモデルに適用してはいけません。訓練データに適用した変換式 $y=\dfrac{1}{2}x$ を適用して、$(0.25,0.75)$ とする必要があります。

つまり、訓練データに fit を適用して求めた変換式を使って、テストデータの transform を行うという流れが一般的になります。

fit、transform、fit_transform を有するクラスの例

・sklearn.preprocessing.MinMaxScaler
最大値と最小値が指定した値になるような線形変換を行います。

・sklearn.preprocessing.StandardScaler
平均が $0$、標準偏差が $1$ になるような線形変換を行います。

・sklearn.decomposition.PCA
主成分分析を行います。

・sklearn.preprocessing.OneHotEncoder
カテゴリ変数からダミー変数を作成します。(One-Hot-Encoding)

次回は 統計における標準化の意味と目的 を解説します。

ページ上部へ戻る