Beam Search, and How to Build Beam Search in TensorFlow

I. Beam Search
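Beam search sits between greedy decoding, which keeps only the single best token at every step, and exhaustive search, which keeps every hypothesis: at each step the beam_width best partial hypotheses so far are expanded, and only the beam_width highest-scoring expansions survive. A minimal, framework-free sketch of the idea (next_log_probs is an assumed placeholder standing in for one decoder step of a real model):

import heapq

def beam_search(next_log_probs, start_id, end_id, beam_width, max_len):
    # next_log_probs(prefix) is assumed to return {token_id: log_prob} for the
    # next token given the prefix; each hypothesis is (total_log_prob, token_ids).
    beams = [(0.0, [start_id])]
    for _ in range(max_len):
        candidates = []
        for score, seq in beams:
            if seq[-1] == end_id:
                # Finished hypotheses are carried over unchanged.
                candidates.append((score, seq))
                continue
            for tok, logp in next_log_probs(seq).items():
                candidates.append((score + logp, seq + [tok]))
        # Keep only the beam_width highest-scoring hypotheses.
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
        if all(seq[-1] == end_id for _, seq in beams):
            break
    return beams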

II. Building Beam Search in TensorFlow

1. Implementation

with tf.variable_scope('decoder'):
    beam_width = 10
    memory = encoder_outputs

    if mode == 'infer':
        # At inference time, tile the encoder results beam_width times along
        # the batch axis so the attention memory matches the beam-expanded batch.
        memory = tf.contrib.seq2seq.tile_batch(memory, beam_width)
        X_len = tf.contrib.seq2seq.tile_batch(X_len, beam_width)
        encoder_state = tf.contrib.seq2seq.tile_batch(encoder_state, beam_width)
        bs = batch_size * beam_width
    else:
        bs = batch_size

    attention = tf.contrib.seq2seq.LuongAttention(hidden_size, memory, X_len, scale=True) # multiplicative
    # attention = tf.contrib.seq2seq.BahdanauAttention(hidden_size, memory, X_len, normalize=True) # additive
    cell = multi_cells(num_layers * 2)
    cell = tf.contrib.seq2seq.AttentionWrapper(cell, attention, hidden_size, name='attention')
    decoder_initial_state = cell.zero_state(bs, tf.float32).clone(cell_state=encoder_state)

    with tf.variable_scope('projected'):
        output_layer = tf.layers.Dense(len(word2id_en), use_bias=False, kernel_initializer=k_initializer)

    if mode == 'infer':
        # Start tokens are of size batch_size; BeamSearchDecoder tiles them to
        # batch_size * beam_width internally (see the note in section 2).
        start = tf.fill([batch_size], word2id_en['<s>'])
        decoder = tf.contrib.seq2seq.BeamSearchDecoder(cell, embeddings_Y, start, word2id_en['</s>'],
                                                       decoder_initial_state, beam_width, output_layer)
        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder,
                                                                            output_time_major=True,
                                                                            maximum_iterations=2 * tf.reduce_max(X_len))
        sample_id = outputs.predicted_ids
    else:
        # Training uses teacher forcing: feed the embedded ground-truth targets.
        helper = tf.contrib.seq2seq.TrainingHelper(embedded_Y, [maxlen_en - 1 for b in range(batch_size)])
        decoder = tf.contrib.seq2seq.BasicDecoder(cell, helper, decoder_initial_state, output_layer)

        outputs, final_context_state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, 
                                                                            output_time_major=True)
        logits = outputs.rnn_output
        logits = tf.transpose(logits, (1, 0, 2))  # time-major -> batch-major
        print(logits)
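At inference time, sample_id above (outputs.predicted_ids) holds word ids for every beam; with output_time_major=True its shape is (max_time, batch_size, beam_width). A possible post-processing sketch that keeps the top beam and maps ids back to words, assuming a hypothetical reverse vocabulary id2word_en built from word2id_en:

import numpy as np

def ids_to_sentences(predicted, id2word_en, eos_id):
    # predicted: array of shape (max_time, batch_size, beam_width),
    # e.g. the result of sess.run(sample_id, feed_dict=...).
    predicted = np.transpose(predicted, (1, 2, 0))  # -> (batch_size, beam_width, max_time)
    best_beam = predicted[:, 0, :]                  # beam 0 is the highest-scoring hypothesis
    sentences = []
    for ids in best_beam:
        words = []
        for i in ids:
            if i == eos_id:
                break                               # stop at the first end-of-sentence token
            words.append(id2word_en[i])
        sentences.append(' '.join(words))
    return sentences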

2. Notes

  • API: tf.contrib.seq2seq.BeamSearchDecoder
  • Inputs
    • The encoder output (memory), the encoder final state (encoder_state) and the source sentence lengths on the encoder side (X_len) must all be replicated with tile_batch, so their final shape becomes (batch_size*beam_width, ...); a shape sketch follows this list.
    • The batch size used to build the AttentionWrapper initial state (zero_state) must likewise be batch_size*beam_width.
    • The start tokens fed to BeamSearchDecoder have size batch_size; during initialization BeamSearchDecoder tiles them beam_width times into batch_size*beam_width, as in this source snippet:
self._start_tokens = array_ops.tile(
    array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
self._start_inputs = self._embedding_fn(self._start_tokens)
  • Outputs
    • BasicDecoder's step function
      • output: a BasicDecoderOutput(rnn_output, sample_id)
      • state: determined by the cell inside BasicDecoder. If the cell is an AttentionWrapper, the state is an AttentionWrapperState; if the AttentionWrapper wraps a MultiRNNCell, AttentionWrapperState.cell_state is a tuple; if the MultiRNNCell is built from LSTM cells, the tuple elements are LSTMStateTuples, while plain RNN or GRU cells give Tensor elements.
    • BeamSearchDecoder's step function
      • output: a BeamSearchDecoderOutput(scores, predicted_ids, parent_ids)
      • state: a BeamSearchDecoderState whose cell_state is determined by the cell inside BeamSearchDecoder in the same way: an AttentionWrapper gives an AttentionWrapperState, a wrapped MultiRNNCell makes AttentionWrapperState.cell_state a tuple, with LSTMStateTuple elements for LSTM cells and Tensor elements for RNN or GRU cells.
    • Running BasicDecoder through dynamic_decode yields a BasicDecoderOutput(rnn_output, sample_id)
      • Compared with the BasicDecoderOutput returned by the step function, rnn_output and sample_id each gain one (time) dimension; output_time_major determines where that dimension is placed.
    • Running BeamSearchDecoder through dynamic_decode yields a FinalBeamSearchDecoderOutput(predicted_ids, beam_search_decoder_output)
      • predicted_ids has shape (max_time, batch_size, beam_width) with output_time_major=True (as in the code above), or (batch_size, max_time, beam_width) with output_time_major=False; transposing gives the usual (batch_size, beam_width, sentence_length) layout.
      • beam_search_decoder_output is a BeamSearchDecoderOutput(scores, predicted_ids, parent_ids)
      • FinalBeamSearchDecoderOutput.beam_search_decoder_output.predicted_ids and FinalBeamSearchDecoderOutput.predicted_ids are different: the former records the per-step choices that build the whole beam search tree, and the latter is obtained by backtracking through that tree to get the final output.
  • Parameters
    • Unlike BasicDecoder, BeamSearchDecoder takes no helper; the _beam_search_step function in the source plays the role of a helper.
    • tf.contrib.seq2seq.BeamSearchDecoder only implements length normalization, controlled by length_penalty_weight.
    • When using BeamSearchDecoder, impute_finished in dynamic_decode must be set to False; setting it to True raises an error. The reason is visible in the source below: when the condition passed to tf.where is a vector, its size must match the first dimension of the x and y tensors, and when it is a higher-rank tensor it must match their full shape. With BeamSearchDecoder, BeamSearchDecoderState contains tensors such as BeamSearchDecoderState.cell_state.cell_state[0].h with shape (batch_size, beam_width, decoder_rnn_size), while finished has shape (batch_size, beam_width), so the shapes do not line up and the tf.where call fails:
      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state
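To illustrate the tiling requirement from the Inputs item above, here is a minimal shape sketch; the sizes are placeholders and not tied to the model in section 1:

import tensorflow as tf

batch_size, max_time, units, beam_width = 4, 7, 16, 10

memory = tf.zeros([batch_size, max_time, units])  # stands in for encoder_outputs
src_len = tf.fill([batch_size], max_time)         # stands in for X_len

# tile_batch repeats each batch entry beam_width times along the batch axis.
tiled_memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
tiled_len = tf.contrib.seq2seq.tile_batch(src_len, multiplier=beam_width)

print(tiled_memory.shape)  # (batch_size * beam_width, max_time, units) = (40, 7, 16)
print(tiled_len.shape)     # (batch_size * beam_width,) = (40,)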

 

3. Error log and solutions

Importing a saved GraphDef that contains a BeamSearchDecoder can fail because the GatherTree op it uses is not registered. A suggested fix is to do this early on:

from tensorflow.contrib.seq2seq.python.ops import beam_search_ops

The likely cause is that, when a GraphDef is imported, the dynamic loading of the .so that registers the GatherTree ops has not yet happened, so adding that import forces the library to load.
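In other words, when a separate inference script imports a GraphDef that contains a BeamSearchDecoder, the shared library that registers GatherTree may not have been loaded yet. A sketch of where the import would go (the model.pb path is a placeholder):

import tensorflow as tf
# Importing this module forces the shared library that registers the
# GatherTree op to load before the graph that needs it is imported.
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops  # noqa: F401

with tf.gfile.GFile('model.pb', 'rb') as f:  # placeholder path to a frozen graph
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

tf.import_graph_def(graph_def, name='')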

4. Source code walkthrough

  • The source essentially builds a tree: during the forward pass each node records (word_id, parent_beam_id, current_score), and the tree is later backtracked to recover the final sequences (a small backtracking sketch follows the source snippets below).
  • Beam search stops once every sequence has predicted EOS:
next_finished = math_ops.logical_or(
    previously_finished,
    math_ops.equal(next_word_ids, end_token),
    name="next_beam_finished")
  • When some sequences have already predicted EOS while others have not, the finished sequences have their next-word distribution masked so that EOS gets log-probability 0 and every other word gets -INF. They therefore keep emitting EOS, their recorded length stops growing (so the length penalty does not grow either), and because log P(EOS) = 0 their total score also stops changing. If beam_width other hypotheses later end up with higher total scores than a finished sequence, that finished sequence can still be pruned.
Related source snippet 1:
  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged.
  # 2. Beams that are now finished (EOS predicted) have their length
  #    increased by 1.
  # 3. Beams that are not yet finished have their length increased by 1.
  lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add

Related source snippet 2:
  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = nn_ops.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # (from _mask_probs, where the arguments are named probs and eos_token)
  # All finished examples are replaced with a vector that has all
  # probability on EOS
  finished_row = array_ops.one_hot(
      eos_token,
      vocab_size,
      dtype=probs.dtype,
      on_value=ops.convert_to_tensor(0., dtype=probs.dtype),
      off_value=probs.dtype.min)
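To make the backtracking idea from the first bullet of this section concrete, here is a small numpy sketch in the spirit of gather_tree: at every step each beam records the word id it chose and the parent beam it extended, and the final sequence is recovered by walking the parent pointers backwards (the toy arrays below are made up for illustration):

import numpy as np

# Toy trace for one example, beam_width = 2, 3 decoding steps.
# predicted_ids[t, k]: word id chosen by beam k at step t
# parent_ids[t, k]:    index of the beam at step t-1 that beam k extends
predicted_ids = np.array([[5, 7],
                          [3, 9],
                          [2, 2]])
parent_ids = np.array([[0, 0],
                       [1, 0],
                       [0, 1]])

def backtrack(predicted_ids, parent_ids, beam):
    # Recover the word-id sequence for one final beam by following its parents.
    seq = []
    for t in reversed(range(predicted_ids.shape[0])):
        seq.append(predicted_ids[t, beam])
        beam = parent_ids[t, beam]  # jump to the beam this hypothesis came from
    return seq[::-1]

print(backtrack(predicted_ids, parent_ids, beam=0))  # [7, 3, 2]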

5. How to bring beam search into the training process

III. Other explanatory material and related code

Explanation

Code

    Original author: 救命稻草人
    Original article: https://zhuanlan.zhihu.com/p/54222636