hiroportation

ITの話だったり、音楽の話、便利なガジェットの話題などを発信しています

CNNをゼロから実装

畳み込みニューラルネットワーク(CNN)の復習です。

画像データの配列

まずはデータを準備します。実際は画像データを前処理するところからやりますが、今回は省略します。 高さh、幅w、チャンネル(色)ch、データの個数nとすると、データの形状は(n, ch, h, w)、となるような画像データをランダムで生成します。

import numpy as np

img = np.random.randint(0, 50, (1, 1, 7, 7)) # 任意の範囲の整数の乱数、最小値0、最大値50
# img = np.round(img)
print(img.shape)

print(img[0].shape)
print(img[0])

ランダムで出力される値を確認する

(1, 1, 7, 7)
(1, 7, 7)
[[[ 4 27  8 19 10 29 40]
  [40 40 24 15 12  8 13]
  [27 41  1  7  1  5 42]
  [ 0 37  5 49 30 31  0]
  [ 7 48 17 14 43 34 37]
  [35 13  8 18 24 45 40]
  [16 29  2  6 12 30 10]]]

次に画像データの持ち方を考えます。

import numpy as np

A = np.array(
    [[["000", "001", "002", "003"],
      ["010", "011", "012", "013"],
      ["020", "021", "022", "023"]],
     [["100", "101", "102", "103"],
      ["110", "111", "112", "113"],
      ["120", "121", "122", "123"]]]
)
print(A)

print(A.shape)
[[['000' '001' '002' '003']
  ['010' '011' '012' '013']
  ['020' '021' '022' '023']]

 [['100' '101' '102' '103']
  ['110' '111' '112' '113']
  ['120' '121' '122' '123']]]
(2, 3, 4)

im2colによる展開

畳み込み演算やプーリング演算を、for文を重ねなくても実装できるように、入力データを展開処理するために使用される関数。

# 引数は
# 画像データ、カーネル高さ、カーネル幅、ストライド幅、ストライド高さ、パディング高さ、パディング幅
# ストライド量、パディング量は縦横まとめられる場合あり
def im2col(img, k_h, k_w, s_h, s_w, p_h, p_w):
    n, c, h, w = img.shape 
    # print(img.shape)

    # パディング処理
    img = np.pad(img, [(0,0), (0,0), (p_h, p_h), (p_w, p_w)], 'constant') 
    # print(img[0])
    # print(img.shape)

    # 出力データのサイズ計算
    out_h = (h + 2*p_h - k_h)//s_h + 1
    out_w = (w + 2*p_w - k_w)//s_w + 1

    col = np.ndarray((n, c, k_h, k_w, out_h, out_w), dtype=img.dtype)  # 戻り値となる4次元配列を準備。(データ数、チャンネル数、カーネル高さ、カーネル幅、出力高さ、出力幅)
    # print(col.shape)
    # print(col[0])
    
    # フィルターに対応する画素をスライス(colに代入)
    for y in range(k_h):
        y_lim = y + s_h * out_h # y_lim:最後のフィルターの位置
        # print("y_lim")
        # print(y_lim)
        for x in range(k_w):
            x_lim = x + s_w * out_w # y_lim:最後のフィルターの位置
            # print("x_lim")
            # print(x_lim)
            col[:, :, y, x, :, :] = img[:, :, y:y_lim:s_h, x:x_lim:s_w] # colのy番目、x番目に、yからy_limまでをストライド量ごとにスライスしたものを代入
            # print("col")
            # print(col)
            # print("img")
            # print(img[:, :, y:y_lim:s_h, x:x_lim:s_w])
    
    # transpose: 多次元配列の軸の順番を入れ替え。reshapeしやすいように、順番を並び替え。(データ数、出力高さ、出力幅、チャンネル数、カーネル高さ、カーネル幅、)
    col = col.transpose(0, 4, 5, 1, 2, 3)

    # reshape: -1を指定することで、多次元配列の要素数を自動整形。(データ数×出力高さ×出力幅 , チャンネル数×カーネル高さ×カーネル幅)
    col = col.reshape(n*out_h*out_w, -1) 
    return col

畳み込み

畳み込みに必要な関数を用意します。

class Convolution:
    def __init__(self, W, b, stride=1, pad=0):
        self.W = W # フィルター(カーネル)
        self.b = b
        self.stride = stride
        self.pad = pad

        # 中間データ
        self.x = None   
        self.col = None
        self.col_W = None

        # 重み・バイアスパラメータの勾配
        self.dW = None
        self.db = None

    def forward(self, x):
        k_n, c, k_h, k_w = self.W.shape # k_n:フィルターの数
        n, c, h, w = x.shape
        
        # 出力データのサイズ計算
        out_h = int((h + 2*self.pad - k_h) / self.stride + 1)
        out_w = int((w + 2*self.pad - k_w) / self.stride + 1)
        
        # 展開
        col = im2col(x, k_h, k_w, self.stride, self.stride, self.pad, self.pad) # 画像を2次元配列化 (データ数×出力高さ×出力幅 , チャンネル数×カーネル高さ×カーネル幅)
        col_W = self.W.reshape(k_n, -1).T # フィルターを2次元配列化
        out = np.dot(col, col_W) + self.b #行列積(畳み込み演算)

        # 整形
        out = out.reshape(n, out_h, out_w, -1).transpose(0, 3, 1, 2) # 2次元配列→4次元配列

        return out

プーリング

画像をMax Poolingしていくための関数を用意します。

class Pooling:
    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        self.x = None
        self.arg_max = None

    def forward(self, x):
        N, C, H, W = x.shape
        out_h = int(1 + (H - self.pool_h) / self.stride)
        out_w = int(1 + (W - self.pool_w) / self.stride)

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h*self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        dmax = np.zeros((dout.size, pool_size))
        #flattenは構造を1次元配列に入れ直すこと
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,)) 

        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)

        return dx

この後CNN実装と画像可視化していきますが、後日追加したいと思います。