コンテンツにスキップ

第5章 あやめの分類

1 モデルの保存

まず、iris.ipynb ファイルで、あやめの機械学習を行い、モデルを作成します。

import pandas as pd
df =  pd.read_csv("iris.csv")

# 特徴量と正解ラベルに分割
x = df[["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]]
y = df["Name"]

# 学習用とテスト用の分割
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state=0)

# モデルのインポート
from sklearn import tree
model = tree.DecisionTreeClassifier()

# 学習する
model.fit(x_train, y_train)

# 結果の検証
model.score(x_test, y_test)

このモデルをstreamlitで使用するので、ファイルとして保存します。python標準のpickleというモジュールを使用すると、オブジェクトの保存が出来ます。

import pickle
with open('iris.pkl','wb') as file:
    pickle.dump(model, file)

2 ページの作成

st_iris.py ファイルに以下を作成します。

import streamlit as st

st.title('あやめの分類')

テキストボックスを配置します。

sepal_length = st.number_input('花がくの長さ')
sepal_width = st.number_input('花がくの幅')
petal_length = st.number_input('花びらの長さ')
petal_width = st.number_input('花びらの幅')

3 予測の実行

まず、モデルを読み込んでおきます。

# modelに読み込み
import pickle
with open('iris.pkl','rb') as f:
    model = pickle.load(f)

ボタンを押したら予測を実行します。

btn = st.button("予測")
if btn:
   # 予測
   result = model.predict([[sepal_length, sepal_width, petal_length, petal_width]])

   st.text(f"結果:{result[0]}")