Pocket
LINEで送る

参考:TensorFlow : How To : データを読む

ファイルからデータを読む場合の処理の流れ

・ファイルパスのリストを、tf.train.string_input_producer 関数に渡す(shuffle=Trueにしたらepochでファイル名をシャッフルする)
・読み込むデータに合わせてreaderを選択する。
・ファイル名キューをreaderのreadメソッドに渡す。readメソッドは、ファイルとレコード識別キーと、スカラ文字列値を返す。
・スカラ文字列を、サンプルを構成するテンソルに変換するためにデコーダと変換 OPs の一つ(あるいはそれ以上)を利用します。

csvファイルを読んでみる

import tensorflow as tf

#ファイルパスのリストをtf.train.string_input_producerに渡す
filename_queue = tf.train.string_input_producer(["./hoge/file0.csv", "./hoge/file1.csv"])

#カンマ区切りCSVファイルは、TextLinerReaderを使う
#TextLinerReaderクラスを呼び出す
reader = tf.TextLineReader()
#readerクラスのreadメソッドにファイル名キューを渡す
#ファイルとレコードの識別キー, スカラ文字列値
key, value = reader.read(filename_queue)

# 空カラムの場合の、デフォルト値。
record_defaults = [[1], [1], [1], [1], [1]]
# decode_csvでファイルの1行の実際のカラム値を取得できる
col1, col2, col3, col4, col5 = tf.decode_csv(
    value, record_defaults=record_defaults)
#col1-col4までが特徴データ、col5がラベルという想定らしい
features = tf.pack([col1, col2, col3, col4])

with tf.Session() as sess:
    # ファイル名キューへのデータ取り込みを開始する。 (Start populating the filename queue.)
    #readを実行するためにrunまたはevalを呼び出す前に、キューのデータを取り込むためには
    #tf.train.start_queue_runnersを呼び出さなければなりません。
    #そうでないとキューからのファイル名を待っている間、readはブロックします。
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1200):
        # 一つのインスタンスを取得する。
        example, label = sess.run([features, col5])

    #runまたはevalが終わったら書く必要があるっぽい。
    coord.request_stop()
    coord.join(threads)

上記コードの動きを色々なパターンでチェックしてみまっす

ちなみに、csvファイルは、下記です。

file0.csv

100,101,102,103,1
200,201,202,203,2
300,301,302,303,3
400,401,402,403,4
500,501,502,503,5
600,601,602,603,6
700,701,702,703,7
800,801,802,803,8

file1.csv

100,101,102,103,11
200,201,202,203,12
300,301,302,303,13
400,401,402,403,14
500,501,502,503,15
600,601,602,603,16
700,701,702,703,17
800,801,802,803,18

下記を試してみます。コードは基本上記と同じですが、exmapleとlabelを3回表示させています。結果は、file0.csvとfile1.csvのどちらかの1~3行目を順番に取得していることが分かりました。ファイルから取り出すのは1行目から順番ですが、取り出すファイル自体は勝手にランダムな感じにしてくれてるっぽいです。

import tensorflow as tf

filename_queue = tf.train.string_input_producer(["./hoge/file0.csv", "./hoge/file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.pack([col1, col2, col3, col4])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(3):
        example, label = sess.run([features, col5])
        print('Step: {}'.format(i))
        print(example)
        print(label)
        print('------------------')

    coord.request_stop()
    coord.join(threads)

結果(下記のときと、file1.csvの内容の時がランダムに変わる)

Step: 0
[100 101 102 103]
1
------------------
Step: 1
[200 201 202 203]
2
------------------
Step: 2
[300 301 302 303]
3
------------------

epochでファイルが変わるか確認する

tf.train.string_input_producerの引数にshuffle=Trueを渡すと、epochでファイル名をシャッフルすると書いてありますので、それも試してみようと思います。まずは、shuffle=Trueを設定しない場合。

import tensorflow as tf

filename_queue = tf.train.string_input_producer(["./hoge/file0.csv", "./hoge/file1.csv"])
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
record_defaults = [[1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5 = tf.decode_csv(value, record_defaults=record_defaults)
features = tf.pack([col1, col2, col3, col4])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(20):
        example, label = sess.run([features, col5])
        print('Step: {}, Label: {}'.format(i, label))

    coord.request_stop()
    coord.join(threads)

結果
下記のようになりました。file0と1だけだと、シャッフルしてるか分かりづらいですね。。とりあえずこれはシャッフルはしてないです。何回かやってみます。

Step: 0, Label: 11
Step: 1, Label: 12
Step: 2, Label: 13
Step: 3, Label: 14
Step: 4, Label: 15
Step: 5, Label: 16
Step: 6, Label: 17
Step: 7, Label: 18
Step: 8, Label: 1
Step: 9, Label: 2
Step: 10, Label: 3
Step: 11, Label: 4
Step: 12, Label: 5
Step: 13, Label: 6
Step: 14, Label: 7
Step: 15, Label: 8
Step: 16, Label: 11
Step: 17, Label: 12
Step: 18, Label: 13
Step: 19, Label: 14

同じコードをもう一度実行したら下記になりました。これぞまさにシャッフルですね。ということは、デフォルトでshuffle=Trueが設定されているようです。

Step: 0, Label: 1
Step: 1, Label: 2
Step: 2, Label: 3
Step: 3, Label: 4
Step: 4, Label: 5
Step: 5, Label: 6
Step: 6, Label: 7
Step: 7, Label: 8
Step: 8, Label: 11
Step: 9, Label: 12
Step: 10, Label: 13
Step: 11, Label: 14
Step: 12, Label: 15
Step: 13, Label: 16
Step: 14, Label: 17
Step: 15, Label: 18
Step: 16, Label: 11
Step: 17, Label: 12
Step: 18, Label: 13
Step: 19, Label: 14

それでは、shuffle=False設定にしてみます。
コードで変更するのは下記だけです。

filename_queue = tf.train.string_input_producer(["./hoge/file0.csv", "./hoge/file1.csv"], shuffle=False)

結果
何回やっても下記になりました。shffle=Falseにすると、最初に取得するファイルも、tf.train.string_input_producerに渡した順番通りで固定のようです。

Step: 0, Label: 1
Step: 1, Label: 2
Step: 2, Label: 3
Step: 3, Label: 4
Step: 4, Label: 5
Step: 5, Label: 6
Step: 6, Label: 7
Step: 7, Label: 8
Step: 8, Label: 11
Step: 9, Label: 12
Step: 10, Label: 13
Step: 11, Label: 14
Step: 12, Label: 15
Step: 13, Label: 16
Step: 14, Label: 17
Step: 15, Label: 18
Step: 16, Label: 1
Step: 17, Label: 2
Step: 18, Label: 3
Step: 19, Label: 4

何をしたらエラーになるか確認してみる

上記のコードで、csvファイルの1行当たりのカラム数が5より大きい場合はエラーになりました。

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expect 5 fields but have 6 in record 0

カラム数が5より小さくてもエラーになりました。

tensorflow.python.framework.errors_impl.InvalidArgumentError: Expect 5 fields but have 4 in record 0

record_defaultsを下記のように変えたらエラーになりました。record_defaultsの形状をみて、想定しているデータの個数やデータ型を想定しているようです。それと合わないデータが入ってたらエラーを出すようです。

record_defaults = [[1.0], [1], [1], [1], [1]]

エラー

ValueError: Tensor conversion requested dtype float32 for Tensor with dtype int32: 'Tensor("DecodeCSV:1", shape=(), dtype=int32)'

record_defaultsを下記のように変えたらエラーになりました。shapeのランクが1じゃないといけないらしいです。shapeが(5, 0)になるからダメで、(5, 1)になるようにしないといけないってことかな?

record_defaults = [1, 1, 1, 1, 1]

エラー

tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 1 but is rank 0 for 'DecodeCSV' (op: 'DecodeCSV') with input shapes: [], [], [], [], [], [].

record_defaultsを下記にしたら、エラーになりました。numpyの配列だとダメっぽいです。

record_defaults = np.zeros([5, 1], dtype=np.int)

エラー

TypeError: Expected list for 'record_defaults' argument to 'DecodeCSV' Op, not [[0] [0] [0] [0] [0]].

これだとOKでした。

record_defaults = np.zeros([5, 1], dtype=np.int).tolist()

固定長バイナリデータファイルを読み込む

各レコードがバイトの固定数である、バイナリ・ファイルを読むためには、 tf.decode_raw 演算とともに tf.FixedLengthRecordReader を使用します。decode_raw 演算は文字列から uint8 テンソルに変換します。

例えば、CIFAR-10 データセット は、各レコードがバイトの固定長を使用して表される、ファイルフォーマットを使用します: ラベルのための1 バイト、続いて 3072 バイトの画像データです。ひとたび uint8 テンソルを持てば、標準操作で各部分をスライスし必要に応じて再フォーマットすることができます。CIFAR-10 については、どのように読みデコードするかを tensorflow/models/image/cifar10/cifar10_input.py で見ることができ、このチュートリアル で説明されています。

コードサンプル

import tensorflow as tf
import os

data_dir = './hoge/cifar-10-batches-bin'
label_bytes = 1
height = 32
width = 32
depth = 3
image_bytes = height * width * depth
record_bytes = label_bytes + image_bytes

filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
             for i in range(1, 6)]
for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

filename_queue = tf.train.string_input_producer(filenames)
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
key, value = reader.read(filename_queue)
data = tf.decode_raw(value, tf.uint8)
label = tf.cast(tf.strided_slice(data, [0], [label_bytes], [1]), tf.int32)
img = tf.reshape(
    tf.strided_slice(data, [label_bytes], [record_bytes], [1]),
    [depth, height, width])
uint8image = tf.transpose(img, [1, 2, 0])

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(1):
        uint8image, label = sess.run([uint8image, label])
        print(label.shape)
        print(uint8image.shape)

    coord.request_stop()
    coord.join(threads)

結果

(1,)
(32, 32, 3)

これで、cifar10のデータを読み込めました。

Pocket
LINEで送る


コメントください

関連記事

プログラミング

Go言語によるビットコインのフルノード実装btcdを調べる(2)

btcdを実行した際のプログラムの流れを最初から確認してみます。 se 続きを読む …

プログラミング

Go – leveldb

Goで使えるLevelDB。 syndtr/goleveldb ドキュ 続きを読む …

%d人のブロガーが「いいね」をつけました。