Pertanyaan Tensorflow: Cara mendapatkan semua variabel dari rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell


Saya memiliki pengaturan di mana saya harus menginisialisasi LSTM setelah inisialisasi utama yang menggunakan tf.initialize_all_variables(). Yaitu. Saya ingin menelepon tf.initialize_variables([var_list]) 

Apakah ada cara untuk mengumpulkan semua variabel internal yang dapat dilatih untuk keduanya:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

sehingga saya dapat menginisialisasi HANYA parameter ini?

Alasan utama saya menginginkan ini adalah karena saya tidak ingin menginisialisasi ulang beberapa nilai terlatih dari sebelumnya.


17
2018-01-26 11:43


asal


Jawaban:


Cara termudah untuk menyelesaikan masalah Anda adalah menggunakan lingkup variabel. Nama-nama variabel dalam lingkup akan diawali dengan namanya. Berikut ini cuplikan singkat:

cell = rnn_cell.BasicLSTMCell(num_nodes)

with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  # Retrieve just the LSTM variables.
  lstm_variables = [v for v in tf.all_variables()
                    if v.name.startswith(vs.name)]

# [..]
# Initialize the LSTM variables.
tf.initialize_variables(lstm_variables)

Itu akan bekerja dengan cara yang sama MultiRNNCell.

EDIT: berubah tf.trainable_variables untuk tf.all_variables()


17
2018-01-26 15:58



Anda juga bisa menggunakan tf.get_collection():

cell = rnn_cell.BasicLSTMCell(num_nodes)
with tf.variable_scope("LSTM") as vs:
  # Execute the LSTM cell here in any way, for example:
  for i in range(num_steps):
    output[i], state = cell(input_data[i], state)

  lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name)

(sebagian disalin dari jawaban Rafal)

Perhatikan bahwa baris terakhir setara dengan pemahaman daftar dalam kode Rafal.

Pada dasarnya, tensorflow menyimpan kumpulan variabel global, yang dapat diambil oleh keduanya tf.all_variables() atau tf.get_collection(tf.GraphKeys.VARIABLES). Jika Anda tentukan scope (nama ruang lingkup) di tf.get_collection() fungsi, maka Anda hanya mengambil tensor (variabel dalam kasus ini) dalam koleksi yang cakupannya berada di bawah lingkup yang ditentukan.

EDIT: Anda juga bisa menggunakan tf.GraphKeys.TRAINABLE_VARIABLES untuk mendapatkan variabel yang dapat dilatih saja. Tetapi karena vanilla BasicLSTMCell tidak menginisialisasi variabel yang tidak dapat dilatih, keduanya akan secara fungsional setara. Untuk daftar lengkap koleksi grafik default, periksa ini di luar.


11
2018-04-29 06:32