Pocket
LINEで送る

MNISTを取得する

TensorFlowのチュートリアル用にMNISTが簡単に取り込めるようになっているらしい。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('.', one_hot=True)

上記を実行するとまだMNISTをダウンロードしてない場合はダウンロードしてくれる。

batch_x, batch_t = mnist.train.next_batch(100)

上記を実行すると、100件分のトレーニング用データとトレーニング用正解ラベルをもらえる。numpyの配列としてもらえる。input_data.read_data_setsの引数で、one_hot=Trueにすると、正解ラベルも正解を1、違うものを0とした配列でもらえる。

やってみること

  • MNISTのトレーニングデータで学習して、テストデータでテストする。
  • トレーニングデータは100件ずつ使って学習する。
  • 最も単純な入力層と出力層しかないものを試してみる。

コード

参考:jupyter_tfbook/Chapter02/MNIST softmax estimation.ipynb

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('.', one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
f = tf.matmul(x, w) + b
y = tf.nn.softmax(f)
t = tf.placeholder(tf.float32, [None, 10])
loss = -tf.reduce_sum(t * tf.log(y))
train_step = tf.train.AdamOptimizer().minimize(loss)

correct = tf.equal(tf.argmax(y, 1), tf.argmax(t, 1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

sess = tf.InteractiveSession()
sess.run(tf.initialize_all_variables())

i = 0
for _ in range(1500):
    i += 1
    batch_x, batch_t = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_x, t: batch_t})
    if i % 100 == 0:
        loss_val, acc_val = sess.run([loss, accuracy], feed_dict={x: mnist.test.images, t: mnist.test.labels})
        print ('Step: %d, Loss: %f, Accuracy: %f' % (i, loss_val, acc_val))

Step: 100, Loss: 7747.071777, Accuracy: 0.848400
Step: 200, Loss: 5439.357910, Accuracy: 0.879900
Step: 300, Loss: 4556.465332, Accuracy: 0.890900
Step: 400, Loss: 4132.032715, Accuracy: 0.896100
Step: 500, Loss: 3836.136963, Accuracy: 0.902600
Step: 600, Loss: 3657.867920, Accuracy: 0.904100
Step: 700, Loss: 3498.280762, Accuracy: 0.907500
Step: 800, Loss: 3376.391602, Accuracy: 0.909400
Step: 900, Loss: 3292.480713, Accuracy: 0.910000
Step: 1000, Loss: 3207.918213, Accuracy: 0.912700
Step: 1100, Loss: 3147.843018, Accuracy: 0.914700
Step: 1200, Loss: 3092.903320, Accuracy: 0.916300
Step: 1300, Loss: 3057.232666, Accuracy: 0.915900
Step: 1400, Loss: 3010.664307, Accuracy: 0.916200
Step: 1500, Loss: 2972.114746, Accuracy: 0.917300

Pocket
LINEで送る


コメントください

関連記事

プログラミング

Go言語によるビットコインのフルノード実装btcdを調べる(2)

btcdを実行した際のプログラムの流れを最初から確認してみます。 se 続きを読む …

プログラミング

Go – leveldb

Goで使えるLevelDB。 syndtr/goleveldb ドキュ 続きを読む …

%d人のブロガーが「いいね」をつけました。