SVMによるデータ分類を行います。
トレーニングデータの準備
必要モジュールのimport
from sklearn import datasets from sklearn import svm import matplotlib.pyplot as plt from sklearn import metrics
トレーニングデータの準備
#データの準備 digits = datasets.load_digits() # データ数の確認 n_samples = len(digits.data) print("データ数:{}".format(n_samples)) #データの可視化 print(digits.data[0]) images_and_labels = list(zip(digits.images, digits.target)) for index, (image, label) in enumerate(images_and_labels[:10]): # enumerate:リストを順番に処理 plt.subplot(2, 5, index + 1) plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.axis('off') plt.title('Training: %i'% label) plt.show()
SVM のロード
clf = svm.SVC(gamma=0.001, C=100.)
6割のデータで学習し、4割のデータでテストする場合
学習実行
# 60%のデータで学習実行 clf.fit(digits.data[:int(n_samples * 6 / 10)], digits.target[:int(n_samples * 6 / 10)])
テストを実行
# 40%のデータでテスト expected = digits.target[int(n_samples *-4 / 10):] predicted = clf.predict(digits.data[int(n_samples *-4 / 10):]) print(clf,metrics.classification_report(expected, predicted)) print(metrics.confusion_matrix(expected, predicted))
予測結果を可視化
images_and_predictions = list(zip(digits.images[int(n_samples *-4 / 10):], predicted)) for index,(image, prediction) in enumerate(images_and_predictions[:12]): plt.subplot(3, 4, index + 1) plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Prediction: %i' % prediction) plt.show()
下記の通り全て予測(Prediction)が(視認する限り)正確に学習されていることがわかる
1割のデータで学習し、8割のデータでテストする場合
学習実行
# 10%のデータで学習実行 clf.fit(digits.data[:int(n_samples * 1 / 10)], digits.target[:int(n_samples * 1 / 10)])
テストを実行
# 90%のデータでテスト expected = digits.target[int(n_samples *-9 / 10):] predicted = clf.predict(digits.data[int(n_samples *-9 / 10):]) print(clf,metrics.classification_report(expected, predicted)) print(metrics.confusion_matrix(expected, predicted))
予測結果を可視化
images_and_predictions = list(zip(digits.images[int(n_samples *-9 / 10):], predicted)) for index,(image, prediction) in enumerate(images_and_predictions[:12]): plt.subplot(3, 4, index + 1) plt.axis('off') plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest') plt.title('Prediction: %i' % prediction) plt.show()
このくらいだと1回の学習でもほぼ正確に予測できているように見える。 次はもっと複雑な形を選びたいと思います。
以上