TensorFlowでオートエンコーダー

最近はTensorFlowの練習として、公式のサンプルの写経をしています。自分なりにコードを解釈してコメントをつけてみたので、公開します。

ソースコード全体

import tensorflow as tf
import numpy as np 
import matplotlib.pyplot as plt 
from tensorflow.examples.tutorials.mnist import input_data

# 画像を保存する用
import os
from PIL import Image

# 学習データを読み込む
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# 訓練用のパラメーター設定
learning_rate = 0.01
num_steps = 30000
batch_size = 256

# 結果を表示する際のパラメータ
display_step = 1000
examples_to_show = 10

# ニューラルネットのパラメータ
num_hidden_1 = 256
num_hidden_2 = 128
num_input = 784


# 重みの設定
weights = {
    "encoder_h1": tf.Variable(tf.random_normal([num_input, num_hidden_1])),
    "encoder_h2": tf.Variable(tf.random_normal([num_hidden_1, num_hidden_2])),
    "decoder_h1": tf.Variable(tf.random_normal([num_hidden_2, num_hidden_1])),
    "decoder_h2": tf.Variable(tf.random_normal([num_hidden_1, num_input])),
}

# バイアスの設定
biases = {
    "encoder_b1": tf.Variable(tf.random_normal([num_hidden_1])),
    "encoder_b2": tf.Variable(tf.random_normal([num_hidden_2])),
    "decoder_b1": tf.Variable(tf.random_normal([num_hidden_1])),
    "decoder_b2": tf.Variable(tf.random_normal([num_input])),
}

# ネットワークのうちエンコード部分
def encoder(x):
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']), biases['encoder_b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']), biases['encoder_b2']))
    return layer_2

# ネットワークのうちのデコード部分
def decoder(x):
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']), biases['decoder_b1']))
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']), biases['decoder_b2']))
    return layer_2

if __name__ == "__main__":
    # 入力変数
    X = tf.placeholder("float", [None, num_input])

    # エンコード部の生成
    encoder_op = encoder(X)
    # デコード部の生成
    decoder_op = decoder(encoder_op)

    # 画像のクラスを予測結果を出力
    y_pred = decoder_op
    # 教師データは入力画像と同じ
    y_true = X

    # 損失関数は最小二乗誤差
    loss = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
    # 最適化手法はRMSProp
    optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

    # 変数の初期化
    init = tf.global_variables_initializer()

    # 学習の処理
    with tf.Session() as sess:
        sess.run(init)

        # 学習のループ
        for i in range(1, num_steps+1):
            # 今回のバッチで使用する画像を読み込む 正解ラベルは使わない
            batch_x, _ = mnist.train.next_batch(batch_size)
            # 学習を行う
            _, l = sess.run([optimizer, loss], feed_dict={X: batch_x})

            # たまにロスの値をプリントする
            if i % display_step == 0 or i == 1:
                print("Step %i: Loss: %f" % (i, l))

        # テスト
        # 入力画像を正確に出力できているか確認する
        n = 4
        canvas_orig = np.empty((28 * n, 28 * n))
        canvas_recon = np.empty((28 * n, 28 * n))
        for i in range(n):
            batch_x, _ = mnist.test.next_batch(n)
            # エンコード、でコードを行なった結果の画像を出力
            g = sess.run(decoder_op, feed_dict={X: batch_x})

            # 入力の画像を表示
            for j in range(n):
                batch_image = batch_x[j].reshape([28, 28])
                canvas_orig[i * 28: (i + 1) * 28, j * 28: (j + 1) * 28] = batch_image
                
            # 出力の画像を表示
            for j in range(n):
                batch_image = g[j].reshape([28, 28])
                canvas_recon[i * 28: (i + 1) * 28, j * 28: (j + 1) * 28] = batch_image  

            print("Original Images")
            plt.figure(figsize=(n, n))
            plt.imshow(canvas_orig, origin="upper", cmap="gray")
            plt.show()

            print("Reconstructed Images")
            plt.figure(figsize=(n, n))
            plt.imshow(canvas_recon, origin="upper", cmap="gray")
            plt.show()

解説

ほとんど公式のサンプルそのままなのですが…

一応解説させてもらいます。オートエンコーダーの大まかな流れは

  1. 画像をニューラルネットに入力
  2. 層のノード数を減らしていく
  3. 入力画像の次元数よりノード数が少なくなるまで減らす
  4. 今度はノード数を増やす
  5. 出力の次元数は入力画像の次元数と同じ
  6. 入力と出力がどれだけ近いか計算

のような感じです。このブログの画像を見ると分かりやすいかと思います。オートエンコーダーの解説も丁寧でしたので、一度見ておくといいんじゃないんでしょうか。

結果

今回ニューラルネットに入力した画像はMNISTで、次のようなものです。

そして、学習を終えて出力された画像はこのようなものでした。

結構ノイズが残っていますが、数字の特徴をなんとなく捉えられているんじゃないかと思います。

こうすることで、入力された画像の特徴を学習したニューラルネットを作ることができます。このニューラルネットを分類問題の学習をする際の初期値に使ったり、画像検索をする際の特徴計算に使ったりできるわけです。

実装はそんなに難しくありませんが、広い応用範囲を持った技術なんじゃないかと思います。

シェアする

  • このエントリーをはてなブックマークに追加

フォローする