Pertanyaan Thread paralel dengan TensorFlow Dataset API dan flat_map


Saya mengubah kode TensorFlow saya dari antarmuka antrian lama ke yang baru API Dataset. Dengan antarmuka lama saya bisa menentukan num_threads argumen untuk tf.train.shuffle_batch antre. Namun, satu-satunya cara untuk mengontrol jumlah untaian di API Dataset tampaknya berada di map fungsi menggunakan num_parallel_calls argumen. Namun, saya menggunakan flat_map berfungsi sebagai gantinya, yang tidak memiliki argumen seperti itu.

Pertanyaan: Apakah ada cara untuk mengontrol jumlah utas / proses untuk flat_map fungsi? Atau apakah ada cara untuk menggunakannya map dalam kombinasi dengan flat_map dan masih menentukan jumlah panggilan paralel?

Perhatikan bahwa sangat penting untuk menjalankan beberapa utas secara paralel, karena saya berniat menjalankan pra-pemrosesan berat pada CPU sebelum data memasuki antrean.

Ada dua (sini dan sini) posting terkait di GitHub, tapi saya tidak berpikir mereka menjawab pertanyaan ini.

Berikut ini contoh kode minimal dari use case saya untuk ilustrasi:

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'

13
2017-11-21 10:56


asal


Jawaban:


Sejauh pengetahuan saya, saat ini flat_map tidak menawarkan opsi paralelisme. Mengingat sebagian besar perhitungan dilakukan di pre_processing_func, apa yang mungkin Anda gunakan sebagai solusi adalah paralel map panggilan diikuti oleh beberapa buffering, dan kemudian menggunakan flat_map panggilan dengan fungsi lambda identitas yang merawat meratakan output.

Dalam kode:

NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data-augmentation here
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return atificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                  map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                  prefetch(BUFFER_SIZE).
                  flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
                  shuffle(BUFFER_SIZE)) # my addition, probably necessary though

Catatan (untuk saya dan siapa pun yang akan mencoba memahami jalur pipa):

Sejak pre_processing_func menghasilkan sejumlah sampel baru secara acak mulai dari sampel awal (disusun dalam matriks bentuk (?, 512)), yang flat_map panggilan diperlukan untuk mengubah semua matriks yang dihasilkan menjadi Datasets mengandung sampel tunggal (maka itu tf.data.Dataset.from_tensor_slices(x) di lambda) dan kemudian meratakan semua dataset ini menjadi satu besar Dataset mengandung sampel individu.

Mungkin itu ide yang bagus .shuffle() dataset itu, atau sampel yang dihasilkan akan dikemas bersama.


8
2017-11-21 13:18