畳み込みニューラルネットワーク(CNN)の復習です。
画像データの配列
まずはデータを準備します。実際は画像データを前処理するところからやりますが、今回は省略します。 高さh、幅w、チャンネル(色)ch、データの個数nとすると、データの形状は(n, ch, h, w)、となるような画像データをランダムで生成します。
import numpy as np img = np.random.randint(0, 50, (1, 1, 7, 7)) # 任意の範囲の整数の乱数、最小値0、最大値50 # img = np.round(img) print(img.shape) print(img[0].shape) print(img[0])
ランダムで出力される値を確認する
(1, 1, 7, 7) (1, 7, 7) [[[ 4 27 8 19 10 29 40] [40 40 24 15 12 8 13] [27 41 1 7 1 5 42] [ 0 37 5 49 30 31 0] [ 7 48 17 14 43 34 37] [35 13 8 18 24 45 40] [16 29 2 6 12 30 10]]]
次に画像データの持ち方を考えます。
import numpy as np A = np.array( [[["000", "001", "002", "003"], ["010", "011", "012", "013"], ["020", "021", "022", "023"]], [["100", "101", "102", "103"], ["110", "111", "112", "113"], ["120", "121", "122", "123"]]] ) print(A) print(A.shape)
[[['000' '001' '002' '003'] ['010' '011' '012' '013'] ['020' '021' '022' '023']] [['100' '101' '102' '103'] ['110' '111' '112' '113'] ['120' '121' '122' '123']]] (2, 3, 4)
im2colによる展開
畳み込み演算やプーリング演算を、for文を重ねなくても実装できるように、入力データを展開処理するために使用される関数。
# 引数は # 画像データ、カーネル高さ、カーネル幅、ストライド幅、ストライド高さ、パディング高さ、パディング幅 # ストライド量、パディング量は縦横まとめられる場合あり def im2col(img, k_h, k_w, s_h, s_w, p_h, p_w): n, c, h, w = img.shape # print(img.shape) # パディング処理 img = np.pad(img, [(0,0), (0,0), (p_h, p_h), (p_w, p_w)], 'constant') # print(img[0]) # print(img.shape) # 出力データのサイズ計算 out_h = (h + 2*p_h - k_h)//s_h + 1 out_w = (w + 2*p_w - k_w)//s_w + 1 col = np.ndarray((n, c, k_h, k_w, out_h, out_w), dtype=img.dtype) # 戻り値となる4次元配列を準備。(データ数、チャンネル数、カーネル高さ、カーネル幅、出力高さ、出力幅) # print(col.shape) # print(col[0]) # フィルターに対応する画素をスライス(colに代入) for y in range(k_h): y_lim = y + s_h * out_h # y_lim:最後のフィルターの位置 # print("y_lim") # print(y_lim) for x in range(k_w): x_lim = x + s_w * out_w # y_lim:最後のフィルターの位置 # print("x_lim") # print(x_lim) col[:, :, y, x, :, :] = img[:, :, y:y_lim:s_h, x:x_lim:s_w] # colのy番目、x番目に、yからy_limまでをストライド量ごとにスライスしたものを代入 # print("col") # print(col) # print("img") # print(img[:, :, y:y_lim:s_h, x:x_lim:s_w]) # transpose: 多次元配列の軸の順番を入れ替え。reshapeしやすいように、順番を並び替え。(データ数、出力高さ、出力幅、チャンネル数、カーネル高さ、カーネル幅、) col = col.transpose(0, 4, 5, 1, 2, 3) # reshape: -1を指定することで、多次元配列の要素数を自動整形。(データ数×出力高さ×出力幅 , チャンネル数×カーネル高さ×カーネル幅) col = col.reshape(n*out_h*out_w, -1) return col
畳み込み
畳み込みに必要な関数を用意します。
class Convolution: def __init__(self, W, b, stride=1, pad=0): self.W = W # フィルター(カーネル) self.b = b self.stride = stride self.pad = pad # 中間データ self.x = None self.col = None self.col_W = None # 重み・バイアスパラメータの勾配 self.dW = None self.db = None def forward(self, x): k_n, c, k_h, k_w = self.W.shape # k_n:フィルターの数 n, c, h, w = x.shape # 出力データのサイズ計算 out_h = int((h + 2*self.pad - k_h) / self.stride + 1) out_w = int((w + 2*self.pad - k_w) / self.stride + 1) # 展開 col = im2col(x, k_h, k_w, self.stride, self.stride, self.pad, self.pad) # 画像を2次元配列化 (データ数×出力高さ×出力幅 , チャンネル数×カーネル高さ×カーネル幅) col_W = self.W.reshape(k_n, -1).T # フィルターを2次元配列化 out = np.dot(col, col_W) + self.b #行列積(畳み込み演算) # 整形 out = out.reshape(n, out_h, out_w, -1).transpose(0, 3, 1, 2) # 2次元配列→4次元配列 return out
プーリング
画像をMax Poolingしていくための関数を用意します。
class Pooling: def __init__(self, pool_h, pool_w, stride=1, pad=0): self.pool_h = pool_h self.pool_w = pool_w self.stride = stride self.pad = pad self.x = None self.arg_max = None def forward(self, x): N, C, H, W = x.shape out_h = int(1 + (H - self.pool_h) / self.stride) out_w = int(1 + (W - self.pool_w) / self.stride) col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad) col = col.reshape(-1, self.pool_h*self.pool_w) arg_max = np.argmax(col, axis=1) out = np.max(col, axis=1) out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2) self.x = x self.arg_max = arg_max return out def backward(self, dout): dout = dout.transpose(0, 2, 3, 1) pool_size = self.pool_h * self.pool_w dmax = np.zeros((dout.size, pool_size)) #flattenは構造を1次元配列に入れ直すこと dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten() dmax = dmax.reshape(dout.shape + (pool_size,)) dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1) dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad) return dx
この後CNN実装と画像可視化していきますが、後日追加したいと思います。