畳み込みニューラルネットワーク(CNN)の復習です。
画像データの配列
まずはデータを準備します。実際は画像データを前処理するところからやりますが、今回は省略します。
高さh、幅w、チャンネル(色)ch、データの個数nとすると、データの形状は(n, ch, h, w)、となるような画像データをランダムで生成します。
import numpy as np
img = np.random.randint(0, 50, (1, 1, 7, 7)) # 任意の範囲の整数の乱数、最小値0、最大値50
# img = np.round(img)
print(img.shape)
print(img[0].shape)
print(img[0])
ランダムで出力される値を確認する
(1, 1, 7, 7)
(1, 7, 7)
[[[ 4 27 8 19 10 29 40]
[40 40 24 15 12 8 13]
[27 41 1 7 1 5 42]
[ 0 37 5 49 30 31 0]
[ 7 48 17 14 43 34 37]
[35 13 8 18 24 45 40]
[16 29 2 6 12 30 10]]]
次に画像データの持ち方を考えます。
import numpy as np
A = np.array(
[[["000", "001", "002", "003"],
["010", "011", "012", "013"],
["020", "021", "022", "023"]],
[["100", "101", "102", "103"],
["110", "111", "112", "113"],
["120", "121", "122", "123"]]]
)
print(A)
print(A.shape)
[[['000' '001' '002' '003']
['010' '011' '012' '013']
['020' '021' '022' '023']]
[['100' '101' '102' '103']
['110' '111' '112' '113']
['120' '121' '122' '123']]]
(2, 3, 4)
im2colによる展開
畳み込み演算やプーリング演算を、for文を重ねなくても実装できるように、入力データを展開処理するために使用される関数。
# 引数は
# 画像データ、カーネル高さ、カーネル幅、ストライド幅、ストライド高さ、パディング高さ、パディング幅
# ストライド量、パディング量は縦横まとめられる場合あり
def im2col(img, k_h, k_w, s_h, s_w, p_h, p_w):
n, c, h, w = img.shape
# print(img.shape)
# パディング処理
img = np.pad(img, [(0,0), (0,0), (p_h, p_h), (p_w, p_w)], 'constant')
# print(img[0])
# print(img.shape)
# 出力データのサイズ計算
out_h = (h + 2*p_h - k_h)//s_h + 1
out_w = (w + 2*p_w - k_w)//s_w + 1
col = np.ndarray((n, c, k_h, k_w, out_h, out_w), dtype=img.dtype) # 戻り値となる4次元配列を準備。(データ数、チャンネル数、カーネル高さ、カーネル幅、出力高さ、出力幅)
# print(col.shape)
# print(col[0])
# フィルターに対応する画素をスライス(colに代入)
for y in range(k_h):
y_lim = y + s_h * out_h # y_lim:最後のフィルターの位置
# print("y_lim")
# print(y_lim)
for x in range(k_w):
x_lim = x + s_w * out_w # y_lim:最後のフィルターの位置
# print("x_lim")
# print(x_lim)
col[:, :, y, x, :, :] = img[:, :, y:y_lim:s_h, x:x_lim:s_w] # colのy番目、x番目に、yからy_limまでをストライド量ごとにスライスしたものを代入
# print("col")
# print(col)
# print("img")
# print(img[:, :, y:y_lim:s_h, x:x_lim:s_w])
# transpose: 多次元配列の軸の順番を入れ替え。reshapeしやすいように、順番を並び替え。(データ数、出力高さ、出力幅、チャンネル数、カーネル高さ、カーネル幅、)
col = col.transpose(0, 4, 5, 1, 2, 3)
# reshape: -1を指定することで、多次元配列の要素数を自動整形。(データ数×出力高さ×出力幅 , チャンネル数×カーネル高さ×カーネル幅)
col = col.reshape(n*out_h*out_w, -1)
return col
畳み込み
畳み込みに必要な関数を用意します。
class Convolution:
def __init__(self, W, b, stride=1, pad=0):
self.W = W # フィルター(カーネル)
self.b = b
self.stride = stride
self.pad = pad
# 中間データ
self.x = None
self.col = None
self.col_W = None
# 重み・バイアスパラメータの勾配
self.dW = None
self.db = None
def forward(self, x):
k_n, c, k_h, k_w = self.W.shape # k_n:フィルターの数
n, c, h, w = x.shape
# 出力データのサイズ計算
out_h = int((h + 2*self.pad - k_h) / self.stride + 1)
out_w = int((w + 2*self.pad - k_w) / self.stride + 1)
# 展開
col = im2col(x, k_h, k_w, self.stride, self.stride, self.pad, self.pad) # 画像を2次元配列化 (データ数×出力高さ×出力幅 , チャンネル数×カーネル高さ×カーネル幅)
col_W = self.W.reshape(k_n, -1).T # フィルターを2次元配列化
out = np.dot(col, col_W) + self.b #行列積(畳み込み演算)
# 整形
out = out.reshape(n, out_h, out_w, -1).transpose(0, 3, 1, 2) # 2次元配列→4次元配列
return out
プーリング
画像をMax Poolingしていくための関数を用意します。
class Pooling:
def __init__(self, pool_h, pool_w, stride=1, pad=0):
self.pool_h = pool_h
self.pool_w = pool_w
self.stride = stride
self.pad = pad
self.x = None
self.arg_max = None
def forward(self, x):
N, C, H, W = x.shape
out_h = int(1 + (H - self.pool_h) / self.stride)
out_w = int(1 + (W - self.pool_w) / self.stride)
col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
col = col.reshape(-1, self.pool_h*self.pool_w)
arg_max = np.argmax(col, axis=1)
out = np.max(col, axis=1)
out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)
self.x = x
self.arg_max = arg_max
return out
def backward(self, dout):
dout = dout.transpose(0, 2, 3, 1)
pool_size = self.pool_h * self.pool_w
dmax = np.zeros((dout.size, pool_size))
#flattenは構造を1次元配列に入れ直すこと
dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
dmax = dmax.reshape(dout.shape + (pool_size,))
dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)
return dx
この後CNN実装と画像可視化していきますが、後日追加したいと思います。