## @package attention
# Module caffe2.python.attention
from caffe2.python import brew


class AttentionType:
    Regular, Recurrent, Dot, SoftCoverage = tuple(range(4))


def s(scope, name):
    # We have to manually scope due to our internal/external blob
    # relationships.
    return "{}/{}".format(str(scope), str(name))


# c_i = \sum_j w_{ij}\textbf{s}_j
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = brew.batch_mat_mul(
        model,
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context


# Calculate a softmax over the passed-in attention energy logits.
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
    encoder_lengths=None,
):
    if encoder_lengths is not None:
        # Mask out positions past each sequence's length before the softmax.
        attention_logits_transposed = model.net.SequenceMask(
            [attention_logits_transposed, encoder_lengths],
            [s(scope, 'masked_attention_logits')],
            mode='sequence',
        )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = brew.softmax(
        model,
        attention_logits_transposed,
        s(scope, 'attention_weights_3d'),
        engine='CUDNN',
        axis=1,
    )
    return attention_weights_3d


# e_{ij} = \textbf{v}^T \tanh(\alpha(\textbf{h}_{i-1}, \textbf{s}_j))
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )
    # [encoder_length, batch_size, 1]
    attention_logits = brew.fc(
        model,
        decoder_hidden_encoder_outputs_sum,
        s(scope, 'attention_logits'),
        dim_in=encoder_output_dim,
        dim_out=1,
        axis=2,
        freeze_bias=True,
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = brew.transpose(
        model,
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0, 2],
    )
    return attention_logits_transposed


# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name,
):
    # [1, batch_size, dim_out]
    output = brew.fc(
        model,
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    # [batch_size, dim_out]
    output = model.net.Squeeze(
        output,
        output,
        dims=[0],
    )
    return output


# Implements the recurrent attention (RecAtt) mechanism described in
# section 4.1 of http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
    encoder_lengths=None,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_prev_attention_context,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            weighted_encoder_outputs,
            decoder_hidden_encoder_outputs_sum_tmp,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
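# Usage sketch (illustrative, not part of the original module): wiring one
# recurrent-attention step into a ModelHelper. The blob names and dimensions
# below are assumptions for the example; in a real decoder they come from the
# surrounding recurrent network (e.g. caffe2.python.rnn_cell).
#
#   from caffe2.python import model_helper
#
#   model = model_helper.ModelHelper(name='recurrent_attention_example')
#   context_t, attention_weights, deps = apply_recurrent_attention(
#       model=model,
#       encoder_output_dim=256,
#       # [batch_size, 256, encoder_length]
#       encoder_outputs_transposed='encoder_outputs_transposed',
#       # [encoder_length, batch_size, 256], precomputed once per sequence
#       weighted_encoder_outputs='weighted_encoder_outputs',
#       # [1, batch_size, 512]
#       decoder_hidden_state_t='decoder_hidden_t',
#       decoder_hidden_state_dim=512,
#       # [1, batch_size, 256], the context output of the previous step
#       attention_weighted_encoder_context_t_prev='attention_context_t_prev',
#       scope='decoder',
#       encoder_lengths='encoder_lengths',
#   )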
def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    # [batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]


def apply_dot_attention(
    model,
    encoder_output_dim,
    # [batch_size, encoder_output_dim, encoder_length]
    encoder_outputs_transposed,
    # [1, batch_size, decoder_state_dim]
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths=None,
):
    if decoder_hidden_state_dim != encoder_output_dim:
        # Project the decoder state into the encoder output space so the
        # dot product is well defined.
        weighted_decoder_hidden_state = brew.fc(
            model,
            decoder_hidden_state_t,
            s(scope, 'weighted_decoder_hidden_state'),
            dim_in=decoder_hidden_state_dim,
            dim_out=encoder_output_dim,
            axis=2,
        )
    else:
        weighted_decoder_hidden_state = decoder_hidden_state_t

    # [batch_size, encoder_output_dim]
    squeezed_weighted_decoder_hidden_state = model.net.Squeeze(
        weighted_decoder_hidden_state,
        s(scope, 'squeezed_weighted_decoder_hidden_state'),
        dims=[0],
    )
    # [batch_size, encoder_output_dim, 1]
    expanddims_squeezed_weighted_decoder_hidden_state = model.net.ExpandDims(
        squeezed_weighted_decoder_hidden_state,
        squeezed_weighted_decoder_hidden_state,
        dims=[2],
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = model.net.BatchMatMul(
        [
            encoder_outputs_transposed,
            expanddims_squeezed_weighted_decoder_hidden_state,
        ],
        s(scope, 'attention_logits'),
        trans_a=1,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, []
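# Usage sketch (illustrative, not part of the original module): dot attention
# needs no precomputed weighted_encoder_outputs, and when
# decoder_hidden_state_dim equals encoder_output_dim it adds no parameters at
# all; otherwise a single projection FC is created under `scope`. The blob
# names and dimensions below are assumptions for the example.
#
#   from caffe2.python import model_helper
#
#   model = model_helper.ModelHelper(name='dot_attention_example')
#   context_t, attention_weights, _ = apply_dot_attention(
#       model=model,
#       encoder_output_dim=256,
#       # [batch_size, 256, encoder_length]
#       encoder_outputs_transposed='encoder_outputs_transposed',
#       # [1, batch_size, 256]
#       decoder_hidden_state_t='decoder_hidden_t',
#       decoder_hidden_state_dim=256,
#       scope='decoder',
#       encoder_lengths='encoder_lengths',
#   )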
def apply_soft_coverage_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
    encoder_lengths,
    coverage_t_prev,
    coverage_weights,
):
    # [batch_size, encoder_output_dim]
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )
    # [batch_size, encoder_length]
    coverage_t_prev_2d = model.net.Squeeze(
        coverage_t_prev,
        s(scope, 'coverage_t_prev_2d'),
        dims=[0],
    )
    # [encoder_length, batch_size]
    coverage_t_prev_transposed = brew.transpose(
        model,
        coverage_t_prev_2d,
        s(scope, 'coverage_t_prev_transposed'),
    )
    # [encoder_length, batch_size, encoder_output_dim]
    scaled_coverage_weights = model.net.Mul(
        [coverage_weights, coverage_t_prev_transposed],
        s(scope, 'scaled_coverage_weights'),
        broadcast=1,
        axis=0,
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [decoder_hidden_encoder_outputs_sum_tmp, scaled_coverage_weights],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
    )
    # [batch_size, encoder_length, 1]
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
        encoder_lengths=encoder_lengths,
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )

    # [batch_size, encoder_length]
    attention_weights_2d = model.net.Squeeze(
        attention_weights_3d,
        s(scope, 'attention_weights_2d'),
        dims=[2],
    )
    # [1, batch_size, encoder_length]; accumulate this step's attention
    # weights into the running coverage vector.
    coverage_t = model.net.Add(
        [coverage_t_prev, attention_weights_2d],
        s(scope, 'coverage_t'),
        broadcast=1,
    )

    return (
        attention_weighted_encoder_context,
        attention_weights_3d,
        [decoder_hidden_encoder_outputs_sum],
        coverage_t,
    )
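# Usage sketch (illustrative, not part of the original module): soft-coverage
# attention threads an extra recurrent state, coverage_t, through the decoder
# steps. In this sketch coverage_t_prev is assumed to start as zeros of shape
# [1, batch_size, encoder_length], and coverage_weights is assumed to be a
# [encoder_length, batch_size, encoder_output_dim] blob (typically a learned
# projection of the encoder outputs). All names below are assumptions.
#
#   context_t, attention_weights, deps, coverage_t = (
#       apply_soft_coverage_attention(
#           model=model,
#           encoder_output_dim=256,
#           encoder_outputs_transposed='encoder_outputs_transposed',
#           weighted_encoder_outputs='weighted_encoder_outputs',
#           decoder_hidden_state_t='decoder_hidden_t',
#           decoder_hidden_state_dim=512,
#           scope='decoder',
#           encoder_lengths='encoder_lengths',
#           coverage_t_prev='coverage_t_prev',
#           coverage_weights='coverage_weights',
#       )
#   )
#   # Feed coverage_t back in as coverage_t_prev at the next decoder step.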