Hataraita.net

馬鹿なこととITと馬鹿なことと競馬と変な資格が好きです。でも馬鹿なことはもっと好きです。

Tensorflow1.0~1.1でSeq2Seqチュートリアル(RNNCell)を使う上での注意点

f:id:bananawoma:20170111122413j:plain

TF1.xでのSeq2Seqチュートリアルの挙動

最近tensorflow1.0でSeq2seqチュートリアルをカスタマイズする機会があったのですが、なかなかハマったので忘備的に記録を残しておこうと思います。

結構、同じようにハマっている方がいらっしゃったみたいですが外国の方が多かったようなので日本語で。

TF0.12からTF1.0への変更点については、日本語化されたドキュメントも多数ありますのでそちらに譲るとして、Seq2Seqチュートリアルをカスタマイズするうえで記載されていない変更点があり、それに苦労しました。

どうもTF1.2ではその対策が取られているみたいですが、詳しく読んでいないので割愛します。

RNNCellの仕様変更

その変更点というのは、RNNCellの仕様変更という部分なのですが、RNNCellの再利用のルールが厳しくなっており、単純にコードを1.0化するだけでは使えないという問題が発生しました。

ValueError: Attempt to reuse RNNCell <tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.BasicLSTMCell object at 0x10210d5c0> with a different variable scope than its first use. First use of cell was with scope 'rnn/multi_rnn_cell/cell_0/basic_lstm_cell', this attempt is with scope 'rnn/multi_rnn_cell/cell_1/basic_lstm_cell'. Please create a new instance of the cell if you would like it to use a different set of weights. If before you were using: MultiRNNCell([BasicLSTMCell(...)] * num_layers), change to: MultiRNNCell([BasicLSTMCell(...) for _ in range(num_layers)]). If before you were using the same cell instance as both the forward and reverse cell of a bidirectional RNN, simply create two instances (one for forward, one for reverse). In May 2017, we will start transitioning this cell's behavior to use existing stored weights, if any, when it is called with scope=None (which can lead to silent model degradation, so this error will remain until then.) 

github.com

 実際のエラーメッセージを上げると上記のエラーになります。同じ変数を違う名前空間で使用することができないといった内容でしょうか。

チュートリアル上でいうと、TF0.12対応のチュートリアルでは、エンコーダで使用するRNNCellとデコーダで使用するRNNCellが同一のものが使われているのですが、仕様上それができなくなったようです。

そのため、エンコーダとデコーダのRNNCellを分けて定義する必要があります。

また、multiRNNCellの定義の仕方をMultiRNNCell([BasicLSTMCell(...)] * num_layers), から MultiRNNCell([BasicLSTMCell(...) for _ in range(num_layers)]).に変えないといけないという指示も出ています。

おそらく、チュートリアルを動かすだけの場合はこの対応のみで回るんでしょうが、今回はカスタマイズしたため、全体的にRNNCellの定義を変更する必要が出てきました。

 

RNNCellが再利用されるべきものか新たに定義すべきものか

ここで気を付けたのが、RNNCellがどのような動きをすべきかというところです。

再利用されるべきものの場合名前空間(VariableScope)に引数として、Reuse=Trueを追加するという対応になりますが、別のものの場合Cell自体の変数定義を別に使用しないといけません。チュートリアルのSeq2SeqModel.pyではRNNCellを別の関数間で再利用しているものがあるので、形は同様の別のRNNCellを利用する場合は新たに変数としてCellを定義してやらないといけません。

同ファイルではDeepCopyという関数を使ってそれを行っていますが、どうもうまくいきません。(DeepCopyでは同じものを使っていると認識するのでしょうか?ごめんなさいよくわかりません)

結構長期間の格闘の末、何とか動くようにはなりましたが、どうもTF1.0~1.1でSeq2Seqチュートリアルを行うのは無理筋のようです。(URLは失念しましたが調べている間に同バージョンでSeq2Seqチュートリアルを使用するのは非推奨という記載も見かけました)

 

Seq2Seqチュートリアルを利用する場合は0.12のままで、もしくは新たなチュートリアルの登場を待ったほうがよさげ

TF0.xのSeq2Seqライブラリについては、TF1.xではregacy_seq2seqライブラリとして遺産的に扱われています。変更を加える予定で(1.2ですでに変更されているかもしれませんが)残されているような形になっています。

今回の1.2への変更でRNNCell周りの仕様が大幅に変更されているっぽいのでそれを基に新たなSeq2seqチュートリアルが出てくるかと期待しているのですが、現状Seq2seqチュートリアルをカスタマイズする場合はTF0.xで行うほうがよさげです。