5 単回帰分析
5.1 単回帰分析
教師有り学習の中でも、分類を行うのでは無く、数値を導き出すのが回帰分析です。例えば気温が上がればアイスクリームの売り上げが伸びる、という場合、気温がどれぐらいならアイスクリームの売り上げがどれぐらいになるか、を予測することができます。
この場合、結果は分類のように数パターンでは無く、あらゆる数値があり得ますので、使うアルゴリズムも変わってきます。 そこで使うのが回帰分析です。
まずは今回使うデータを読み込み表示してみます。気温とアイスクリームの売り上げです。 ファイル ice.ipynb に以下を記述します。
import pandas as pd
df = pd.read_csv("ice.csv")
df
今回は気温を元に売上を予測します。 気温と売上には相関関係があります。散布図で確認してみましょう。
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme(font=["Meiryo"])
df.plot.scatter(x="気温", y="売上")
plt.show()
気温が上がると売上が上がります。 回帰直線を引いてみましょう。
sns.regplot(data=df, x="気温", y="売上", line_kws={"color":"red"})
plt.show()
このような場合、この場合、元となる特徴量が1つ(気温)しかありません。このような回帰分析を単回帰分析といいます。特徴量が複数の場合、重回帰分析と言います。
5.2 外れ値の削除
散布図を見ると外れ値が左上にあることに気がつきます。回帰分析の場合、外れ値に大きく影響されやすいため、原因を確認したうえで削除や別処理を行うことがあります。今回は、明らかに不自然な外れ値を除外します。 外れ値を探し、そのインデックスを求めてdropで削除します。
# 気温25度以下 かつ 売上80以上
no = df[(df["気温"]<=25) & (df["売上"]>=80)].index
df = df.drop(no)
5.3 モデルの構築と学習
まずは単回帰分析を行ってみましょう。特徴量を x に、正解ラベルを y に入れます。
x = df[["気温"]]
y = df["売上"]
また、学習用とテスト用のデータに分割します。
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state=0)
今回使用するアルゴリズムは線形回帰(LinearRegression)です。インポートし、学習します。
# 線形回帰モデル
from sklearn.linear_model import LinearRegression
model = LinearRegression()
# 学習する
model.fit(x_train, y_train)
線形回帰分析は以下のような式を作り予測をします。
特徴量×A+B=正解
では予測を行ってみましょう。
model.predict([[32.0]])
5.4 決定係数
では、評価を出してみましょう。回帰の評価は分類の場合と違い、正解率ではありません。決定係数という値になり、この値が1.0に近いほどモデルの式が実際のデータに当てはまりやすくなります。
# 決定係数
model.score(x_test, y_test)
決定係数は自然科学の分野では0.9以上が望ましいですが、人間の行動が変動要因である社会科学・経済学などの分野では0.3~0.5程度でも有意義とされますので、問題ない数字が出ているはずです。
また、どれぐらい誤差があったか、その平均値を「平均絶対誤差」(mean absolute error)として表示できます。
# 予測結果
y_pred = model.predict(x_test)
# 平均絶対誤差
from sklearn.metrics import mean_absolute_error
mean_absolute_error(y_test, y_pred)
5.5 回帰式
線形回帰分析は以下のような式(回帰式)を作り予測をします。
- 特徴量×係数+切片
係数は model.coef_ 、切片は model.intercept_ で表示できます。
print("係数", model.coef_)
print("切片", model.intercept_)
切片の意味はグラフの範囲を広げると分かります。xが0のときにyが何になるかです
plt.xlim(0, 35) # x軸の範囲
plt.ylim(-50, 90) # y軸の範囲
sns.regplot(data=df, x="気温", y="売上", line_kws={"color":"red"}, truncate=False)
plt.show()
例えば係数が3.44、切片が-31.32 の場合、xが0のときはyは-31.32です。xが1増える毎に3.44ずつ増えていきます。
ですので、係数3.44、切片-31.32 の場合、以下のようにして売上を計算できます。
気温×3.44 + (-31.32)