Seq2Seq with Attention

Brief Outline

  • Previously, we took a single encoder state summarising the entire input sentence and fed that same state to the decoder at every step.
  • However, the decoder does not need the full summary at every time step, since the word produced at a given step usually depends on only a part of the input sentence.
  • Forcing one fixed vector to carry all of this information also overloads the decoder.
  • Can we instead feed the decoder, at each time step, a weighted sum of the encoder states, where the weights indicate which encoder states are important?
  • The answer is attention.

Attention

  • To enable attention, we define a function
    $e_{jt} = f_{ATT}(s_{t-1}, h_j)$
  • This quantity captures the importance of the jth input word for decoding the tth output word.
  • Since the attention weights for the $t$th output word need to sum to one over all input words, we apply the softmax function to these scores:
    $\alpha_{jt} = \dfrac{\exp(e_{jt})}{\sum_{k=1}^{M} \exp(e_{kt})}$
  • One of many possible choices of fATT is
    $f_{ATT} = V_{att}^{T} \tanh(U_{att} s_{t-1} + W_{att} h_j)$
  • Where
    • $h_j \in \mathbb{R}^{d_1 \times 1}$
    • $s_t \in \mathbb{R}^{d_2 \times 1}$
  • And
    • $V_{att} \in \mathbb{R}^{d_1 \times 1}$
    • $U_{att} \in \mathbb{R}^{d_1 \times d_2}$
    • $W_{att} \in \mathbb{R}^{d_1 \times d_1}$
  • With these dimensions, $e_{jt}$ (and hence $\alpha_{jt}$) is a scalar.
  • These parameters are learned jointly with the other parameters of the encoder and decoder, as in the sketch below.
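
As a concrete illustration, here is a minimal NumPy sketch of this additive scoring function. The sizes `d1`, `d2` and the name `attention_score` are assumptions chosen only to match the dimensions listed above, not part of any library.

```python
import numpy as np

# Additive attention scorer: e_jt = V_att^T tanh(U_att s_{t-1} + W_att h_j)
# Dimensions follow the notes: h_j in R^{d1}, s_{t-1} in R^{d2}.

d1, d2 = 4, 6                       # hypothetical encoder / decoder hidden sizes
V_att = np.random.randn(d1)         # plays the role of V_att in R^{d1 x 1}
U_att = np.random.randn(d1, d2)     # R^{d1 x d2}
W_att = np.random.randn(d1, d1)     # R^{d1 x d1}

def attention_score(s_prev, h_j):
    """Scalar importance e_jt of encoder state h_j for decoding step t."""
    return V_att @ np.tanh(U_att @ s_prev + W_att @ h_j)

# The result is a single scalar, matching the dimension check above.
e_jt = attention_score(np.random.randn(d2), np.random.randn(d1))
```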

Architecture

(Architecture diagram not available.)

Forward propagation

Encoder

  • $x_j = \text{word embedding} \in \mathbb{R}^{e_1 \times 1}$
  • $h_j = \text{RNN}(h_{j-1}, x_j) \in \mathbb{R}^{d_1 \times 1}$
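
A minimal sketch of this encoder recurrence, assuming a plain (tanh) RNN cell; the vocabulary size, embedding matrix `E`, and weight names are hypothetical.

```python
import numpy as np

# Encoder sketch with a plain RNN cell: h_j = tanh(W_hh h_{j-1} + W_xh x_j)
vocab_size, e1, d1 = 10, 5, 4            # hypothetical sizes
E = np.random.randn(vocab_size, e1)      # word-embedding matrix (one row per word)
W_xh = np.random.randn(d1, e1)
W_hh = np.random.randn(d1, d1)

def encode(token_ids):
    """Return the encoder hidden states h_1 .. h_T for one input sentence."""
    h = np.zeros(d1)
    states = []
    for idx in token_ids:
        x = E[idx]                        # x_j: embedding of the j-th input word
        h = np.tanh(W_hh @ h + W_xh @ x)  # h_j = RNN(h_{j-1}, x_j)
        states.append(h)
    return states

encoder_states = encode([1, 3, 7, 2])     # toy input sentence, T = 4
```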

Attention

  • $e_{jt} = V_{att}^{T} \tanh(U_{att} s_{t-1} + W_{att} h_j)$
  • $\alpha_{jt} = \text{softmax}(e_{jt})$
  • $c_t = \sum_{j=1}^{T} \alpha_{jt} h_j$
  • This $c_t$ is the context vector, a weighted sum of the encoder hidden states, that is passed to the decoder at each time step $t$ to compute the decoder hidden state $s_t$; one possible implementation is sketched below.
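
Putting the three equations above together, here is a sketch of one attention step at decoder time $t$; `score_fn` stands for the additive scorer sketched earlier and is an assumed name, not a fixed API.

```python
import numpy as np

def context_vector(s_prev, encoder_states, score_fn):
    """Compute alpha_jt over the encoder states, then c_t = sum_j alpha_jt h_j."""
    e = np.array([score_fn(s_prev, h_j) for h_j in encoder_states])  # e_jt for all j
    alpha = np.exp(e - e.max())                # numerically stable softmax
    alpha = alpha / alpha.sum()                # alpha_jt sums to 1 over j
    H = np.stack(encoder_states)               # shape (T, d1)
    c_t = (alpha[:, None] * H).sum(axis=0)     # weighted sum of encoder states
    return c_t, alpha
```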

Decoder

  • $s_t = \text{RNN}(s_{t-1}, c_t)$
  • $l_t = \text{softmax}(V s_t + b)$
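
A sketch of one decoder step under the same assumptions (plain tanh RNN cell, hypothetical weight names); $l_t$ is the probability distribution over the output vocabulary.

```python
import numpy as np

d1, d2, out_vocab = 4, 6, 10             # hypothetical sizes
W_cs = np.random.randn(d2, d1)           # maps the context vector c_t
W_ss = np.random.randn(d2, d2)           # maps the previous decoder state s_{t-1}
V = np.random.randn(out_vocab, d2)
b = np.zeros(out_vocab)

def decoder_step(s_prev, c_t):
    """One decoder time step: new state s_t and output distribution l_t."""
    s_t = np.tanh(W_ss @ s_prev + W_cs @ c_t)   # s_t = RNN(s_{t-1}, c_t)
    logits = V @ s_t + b
    l_t = np.exp(logits - logits.max())
    l_t = l_t / l_t.sum()                       # l_t = softmax(V s_t + b)
    return s_t, l_t
```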