hiroportation

ITの話だったり、音楽の話、便利なガジェットの話題などを発信しています

機械学習アルゴリズムの復習(SVM)

SVMによるデータ分類を行います。

レーニングデータの準備

必要モジュールのimport

from sklearn import datasets
from sklearn import svm
import matplotlib.pyplot as plt
from sklearn import metrics

レーニングデータの準備

#データの準備
digits = datasets.load_digits()

# データ数の確認
n_samples = len(digits.data)
print("データ数:{}".format(n_samples))

#データの可視化
print(digits.data[0])

images_and_labels = list(zip(digits.images, digits.target))

for index, (image, label) in enumerate(images_and_labels[:10]): # enumerate:リストを順番に処理
    plt.subplot(2, 5, index + 1)
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.axis('off')
    plt.title('Training: %i'% label)
plt.show()

f:id:thelarklife1021:20211001054328p:plain:h300

SVM のロード

clf = svm.SVC(gamma=0.001, C=100.)

6割のデータで学習し、4割のデータでテストする場合

学習実行

# 60%のデータで学習実行
clf.fit(digits.data[:int(n_samples * 6 / 10)], digits.target[:int(n_samples * 6 / 10)])

テストを実行

# 40%のデータでテスト
expected = digits.target[int(n_samples *-4 / 10):]
predicted = clf.predict(digits.data[int(n_samples *-4 / 10):])

print(clf,metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))

予測結果を可視化

images_and_predictions = list(zip(digits.images[int(n_samples *-4 / 10):], predicted))
for index,(image, prediction) in enumerate(images_and_predictions[:12]):
 plt.subplot(3, 4, index + 1)
 plt.axis('off')
 plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
 plt.title('Prediction: %i' % prediction)
plt.show()

下記の通り全て予測(Prediction)が(視認する限り)正確に学習されていることがわかる

f:id:thelarklife1021:20211001054918p:plain:h300

1割のデータで学習し、8割のデータでテストする場合

学習実行

# 10%のデータで学習実行
clf.fit(digits.data[:int(n_samples * 1 / 10)], digits.target[:int(n_samples * 1 / 10)])

テストを実行

# 90%のデータでテスト
expected = digits.target[int(n_samples *-9 / 10):]
predicted = clf.predict(digits.data[int(n_samples *-9 / 10):])

print(clf,metrics.classification_report(expected, predicted))
print(metrics.confusion_matrix(expected, predicted))

予測結果を可視化

images_and_predictions = list(zip(digits.images[int(n_samples *-9 / 10):], predicted))
for index,(image, prediction) in enumerate(images_and_predictions[:12]):
 plt.subplot(3, 4, index + 1)
 plt.axis('off')
 plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
 plt.title('Prediction: %i' % prediction)
plt.show()

f:id:thelarklife1021:20211001060348p:plain:h300

このくらいだと1回の学習でもほぼ正確に予測できているように見える。 次はもっと複雑な形を選びたいと思います。

以上