Recurrent Networks and LSTM deep dive


Recurrent Neural Networks

Alex Kalinin alex@alexkalinin.com

Content

1. Example of a vanilla RNN
2. RNN forward pass
3. RNN backward pass
4. LSTM design

RNN Training problem

Feed-forward (“vanilla”) network: an input vector X is mapped directly to an output y.

RNN: a hidden state h sits between X and y. The input enters through W_hx, the previous hidden state feeds back through W_hh, and the output is read out through W_hy.

Vanilla recurrent network

(1) h_t = tanh(W_hh h_{t-1} + W_hx x + b_h)

(2) y = W_hy h_t + b_y
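A minimal NumPy sketch of these two equations; the sizes and random weights are illustrative, not taken from the slides:

```python
import numpy as np

def rnn_step(x, h_prev, W_hh, W_hx, W_hy, b_h, b_y):
    """One step of the vanilla RNN: equations (1) and (2) above."""
    h = np.tanh(W_hh @ h_prev + W_hx @ x + b_h)   # (1) new hidden state
    y = W_hy @ h + b_y                            # (2) output scores
    return h, y

# Illustrative sizes: 4-dim one-hot input, 3-dim hidden state, 4-dim output.
rng = np.random.default_rng(0)
W_hh, W_hx = rng.normal(size=(3, 3)), rng.normal(size=(3, 4))
W_hy, b_h, b_y = rng.normal(size=(4, 3)), np.zeros(3), np.zeros(4)

h = np.zeros(3)                          # initial hidden state
x = np.array([0., 1., 0., 0.])           # one-hot "h" in the vocabulary [e, h, l, o]
h, y = rnn_step(x, h, W_hh, W_hx, W_hy, b_h, b_y)
print(h, y)
```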

Example: character-level language processing

Training sequence: “hello”
Vocabulary: [e, h, l, o]

Each character is a one-hot vector over the vocabulary:
“h” = [0, 1, 0, 0], “e” = [1, 0, 0, 0], “l” = [0, 0, 1, 0], “o” = [0, 0, 0, 1]

Example weights, for a hidden state of size 1:

W_hx = [3.6, -4.8, 0.35, -0.26]
W_hy = [-12., -0.67, -0.85, 14.]
b_y = [-0.2, -2.9, 6.1, -3.4]
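A small sketch of the one-hot encoding used here (vocabulary order as above):

```python
import numpy as np

vocab = ['e', 'h', 'l', 'o']
char_to_ix = {ch: i for i, ch in enumerate(vocab)}

def one_hot(ch):
    """Encode a character as a one-hot vector over the vocabulary [e, h, l, o]."""
    v = np.zeros(len(vocab))
    v[char_to_ix[ch]] = 1.0
    return v

for ch in "hello":
    print(ch, one_hot(ch))
# "h" -> [0,1,0,0], "e" -> [1,0,0,0], "l" -> [0,0,1,0], "o" -> [0,0,0,1]
```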

“hello” RNN

Feed the characters of “hello” through the network one at a time. At each step h_t = tanh(W_hh h_{t-1} + W_hx x + b_h) and y = W_hy h_t + b_y; P gives the probability of each next character (a softmax over y reproduces the numbers below).

Step 1: input “h” = [0, 1, 0, 0], h_0 = 0
  h = -0.99
  y = [11., -2.2, 6.9, -17.]
  P = [0.99, 0, 0.01, 0]  -> predicts “e”

Step 2: input “e” = [1, 0, 0, 0], previous h = -0.99
  h = -0.09
  y = [0.86, -2.8, 6.2, -4.6]
  P = [0, 0, 0.99, 0]  -> predicts “l”

Step 3: input “l” = [0, 0, 1, 0], previous h = -0.09
  h = 0.38
  y = [-4.7, -3.2, 5.8, 1.9]
  P = [0, 0, 0.98, 0.02]  -> predicts “l”

Step 4: input “l” = [0, 0, 1, 0], previous h = 0.38
  h = 0.98
  y = [-12., -3.6, 5.3, 10.]
  P = [0, 0, 0.01, 0.99]  -> predicts “o”

Summary (input character, previous hidden state -> prediction):
  “h”, h_0 = 0   -> “e”
  “e”, h = -0.99 -> “l”
  “l”, h = -0.09 -> “l”
  “l”, h = 0.38  -> “o”

Given “h e l l” the network predicts “e l l o”: it has learned to continue “hello”.
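A sketch of the first step (“h”, with h_0 = 0) using the weights above. It assumes a hidden size of 1, b_h = 0, and a softmax over y to produce P; the probabilities match the slide, and y comes out close (the displayed weights are rounded):

```python
import numpy as np

W_hx = np.array([3.6, -4.8, 0.35, -0.26])   # maps the 4-dim one-hot to the 1-dim hidden state
W_hy = np.array([-12., -0.67, -0.85, 14.])  # maps the 1-dim hidden state to 4 output scores
b_y  = np.array([-0.2, -2.9, 6.1, -3.4])

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

x = np.array([0., 1., 0., 0.])              # one-hot "h"
h = np.tanh(W_hx @ x + 0.0)                 # h_0 = 0, b_h assumed 0  -> about -0.9999
y = W_hy * h + b_y                          # -> about [11.8, -2.2, 6.9, -17.4]
p = softmax(y)                              # -> about [0.99, 0.00, 0.01, 0.00], predicts "e"
print(h, y.round(1), p.round(2))
```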

The same approach scales to longer training sequences:

“hello” -> “hello”
“hello ben” -> “hello ben”
“hello world” -> “hello world”

“it was” -> “it was”
“it was the” -> “it was the”
“it was the best” -> “it was the best”

Training text: “It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness… “, A Tale of Two Cities, Charles Dickens

50,000
300,000 (loss = 1.6066)
1,000,000 (loss = 1.8197)
2,000,000 (loss = 4.0844): “it was the best of” -> “it wes the best of”

…epoch 500000, loss: 6.447782290456328
…epoch 1000000, loss: 5.290576956983398
…epoch 1800000, loss: 4.267105168323299
epoch 1900000, loss: 4.175163586546514
epoch 2000000, loss: 4.0844739848413285
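The slides do not define the loss being logged; a standard choice for a character-level RNN is the cross-entropy against the next character, summed over the sequence. A sketch, assuming that loss:

```python
import numpy as np

def sequence_loss(probs, targets):
    """Cross-entropy: sum of -log p(correct next character) over the sequence."""
    return -sum(np.log(p[t]) for p, t in zip(probs, targets))

# For "hell" -> "ello", using the probabilities from the walkthrough above:
probs = [np.array([0.99, 0.0, 0.01, 0.0]),   # after "h", p("e") = 0.99
         np.array([0.0, 0.0, 0.99, 0.0]),    # after "e", p("l") = 0.99
         np.array([0.0, 0.0, 0.98, 0.02]),   # after "l", p("l") = 0.98
         np.array([0.0, 0.0, 0.01, 0.99])]   # after "l", p("o") = 0.99
targets = [0, 2, 2, 3]                       # indices of e, l, l, o in [e, h, l, o]
print(sequence_loss(probs, targets))         # about 0.05: the net has learned "hello"
```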

Vanilla recurrent network (recap):

(1) h_t = tanh(W_hh h_{t-1} + W_hx x + b_h)

(2) y = W_hy h_t + b_y

Input:  “it was t”
Target: “t was th”

The target at each time step is the next character of the input, i.e. the input sequence shifted one character to the left.
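A small sketch of how such (input, target) pairs are built from raw text:

```python
text = "it was the best of times"

# The input is text[:-1]; the target at each position is the next character.
inputs  = list(text[:-1])
targets = list(text[1:])

for x, t in zip(inputs[:8], targets[:8]):
    print(repr(x), "->", repr(t))
# 'i' -> 't', 't' -> ' ', ' ' -> 'w', 'w' -> 'a', 'a' -> 's', 's' -> ' ', ' ' -> 't', 't' -> 'h'
```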

RNNs for Different Problems

Vanilla Neural Network: fixed-size input -> fixed-size output
Image Captioning: image -> sequence of words
Sentiment Analysis: sequence of words -> class
Translation: sequence of words -> sequence of words

Unrolled over three time steps: inputs x_0 = 1, x_1 = 1, x_2 = 2 produce hidden states h_0, h_1, h_2; the target output is 3.

The loss is a function of the weights: L = f(W_hx, W_hh, W_hy), with, for example, W_hh = 0.024.

Gradient descent nudges each weight against its gradient:

w_hx := w_hx - 0.01 · ∂L/∂w_hx
w_hh := w_hh - 0.01 · ∂L/∂w_hh
w_hy := w_hy - 0.01 · ∂L/∂w_hy
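These update rules as code; a minimal sketch, where the weight and gradient values are the toy numbers worked out later in the deck:

```python
learning_rate = 0.01

def sgd_update(w, dL_dw, lr=learning_rate):
    """One gradient-descent step: w := w - lr * dL/dw."""
    return w - lr * dL_dw

# Illustrative numbers (the toy values from the worked example below):
w_hx, dL_dw_hx = 0.078, -0.00017
w_hh, dL_dw_hh = 0.024, -0.0005
print(sgd_update(w_hx, dL_dw_hx))   # 0.0780017
print(sgd_update(w_hh, dL_dw_hh))   # 0.024005
```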

Training is hard with vanilla RNNs

∇L = [∂L/∂w_hx, ∂L/∂w_hh, ∂L/∂w_hy]

Forward pass: run the network and compute the loss from the current W_hx, W_hh, W_hy.
Backward pass: compute the gradient of the loss with respect to each of those weights.

∂L/∂w_hh = ?   What is L for the unrolled network above?

Compute the gradient by recursive application of the chain rule: if

L = f(g(h(k(l(m(n(w)))))))

then

∂L/∂w = (∂f/∂g) · (∂g/∂h) · (∂h/∂k) · (∂k/∂l) · (∂l/∂m) · (∂m/∂n) · (∂n/∂w)

For the three-step network (output weight W_hy, target 3):

L = (W_hy · tanh(W_hh · tanh(W_hh · tanh(W_hx·x_0) + W_hx·x_1) + W_hx·x_2) - 3)²

Gradient by hand

Toy values, hidden size 1, no bias terms:
W_hx = 0.078, W_hh = 0.024, W_hy = 0.051; inputs x_0 = 1, x_1 = 1, x_2 = 2; target output 3.

Forward Pass

W_hx · x_0 = 0.078 · 1 = 0.078
h_0 = tanh(0.078) = 0.0778
W_hh · h_0 = 0.024 · 0.0778 = 0.00187
W_hx · x_1 = 0.078 · 1 = 0.078
h_1 = tanh(0.00187 + 0.078) = tanh(0.07987) = 0.0797
W_hh · h_1 = 0.024 · 0.0797 = 0.0019
W_hx · x_2 = 0.078 · 2 = 0.156
h_2 = tanh(0.0019 + 0.156) = tanh(0.1579) = 0.1566
y = W_hy · h_2 = 0.051 · 0.1566 = 0.0080
y - 3 = -2.99
L = (y - 3)² = (-2.99)² = 8.95
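The same forward pass as a short NumPy script (hidden size 1, no bias terms, exactly the values above):

```python
import numpy as np

w_hx, w_hh, w_hy = 0.078, 0.024, 0.051
x0, x1, x2, target = 1.0, 1.0, 2.0, 3.0

h0 = np.tanh(w_hx * x0)              # tanh(0.078)           = 0.0778
h1 = np.tanh(w_hh * h0 + w_hx * x1)  # tanh(0.00187 + 0.078) = 0.0797
h2 = np.tanh(w_hh * h1 + w_hx * x2)  # tanh(0.0019 + 0.156)  = 0.1566
y  = w_hy * h2                       # 0.051 * 0.1566        = 0.0080
L  = (y - target) ** 2               # (-2.99)**2            = 8.95
print(h0, h1, h2, y, L)
```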

Backward Pass

Apply the chain rule from the loss back to the weights, multiplying one local derivative at a time:

∂L/∂w_hx = (∂f/∂f) · (∂f/∂g) · (∂g/∂h) · (∂h/∂k) · (∂k/∂l) · (∂l/∂m) · (∂m/∂n) · (∂n/∂w_hx)

Running value of the gradient as it flows back through the graph (using the forward-pass values above; z_t denotes the pre-tanh sum at step t, e.g. z_2 = W_hh·h_1 + W_hx·x_2):

start at the loss:                             ∂L/∂L = 1
× ∂(g²)/∂g = 2g = 2·(-2.99) = -5.98            -> -5.98
× ∂(y - 3)/∂y = 1                              -> -5.98
× ∂(W_hy·h_2)/∂h_2 = W_hy = 0.051              -> -0.304
    (branching into the weight instead: ∂(W_hy·h_2)/∂W_hy = h_2 = 0.1566, so ∂L/∂W_hy ≈ -5.98 · 0.1566 ≈ -0.94)
× 1 - h_2² = 1 - 0.1566² = 0.975               -> -0.297
× ∂z_2/∂h_1 = W_hh = 0.024                     -> -0.0071
× 1 - h_1² = 1 - 0.0797² ≈ 0.99                -> -0.0071
    (branching into W_hh here: ∂z_1/∂W_hh = h_0 = 0.0778, so ∂L/∂w_hh ≈ -0.0071 · 0.0778 ≈ -0.0005)
× ∂z_1/∂h_0 = W_hh = 0.024                     -> -0.00017
× 1 - h_0² = 1 - 0.0778² ≈ 0.99                -> -0.00017
× ∂z_0/∂w_hx = x_0 = 1                         -> ∂L/∂w_hx ≈ -0.00017

Update the weights with learning rate 0.01:

w_hx := 0.078 - 0.01 · (-0.00017) = 0.0780017
w_hh := 0.024 - 0.01 · (-0.0005) = 0.024005
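The backward pass as code. This follows the same single chain of local derivatives traced above; note that a full backpropagation-through-time would also accumulate the W_hh and W_hx contributions from every time step, whereas here, as on the slides, only the path shown is followed:

```python
import numpy as np

w_hx, w_hh, w_hy = 0.078, 0.024, 0.051
x0, x1, x2, target = 1.0, 1.0, 2.0, 3.0

# Forward pass (as above), keeping the intermediate values.
h0 = np.tanh(w_hx * x0)
h1 = np.tanh(w_hh * h0 + w_hx * x1)
h2 = np.tanh(w_hh * h1 + w_hx * x2)
y  = w_hy * h2
g  = y - target                      # -2.99

# Backward pass: multiply local derivatives along the chain.
dL      = 1.0
dL_dy   = dL * 2 * g                 # d(g^2)/dg = 2g      -> -5.98
dL_dwhy = dL_dy * h2                 # branch into W_hy    -> -0.94
dL_dh2  = dL_dy * w_hy               #                     -> -0.304
dL_dz2  = dL_dh2 * (1 - h2**2)       # through tanh        -> -0.297
dL_dh1  = dL_dz2 * w_hh              #                     -> -0.0071
dL_dz1  = dL_dh1 * (1 - h1**2)       # through tanh        -> -0.0071
dL_dwhh = dL_dz1 * h0                # branch into W_hh    -> -0.0005
dL_dh0  = dL_dz1 * w_hh              #                     -> -0.00017
dL_dz0  = dL_dh0 * (1 - h0**2)       # through tanh        -> -0.00017
dL_dwhx = dL_dz0 * x0                # finally, dL/dW_hx   -> -0.00017

# Gradient-descent update with learning rate 0.01 (as on the slides):
print(w_hx - 0.01 * dL_dwhx)         # 0.0780017
print(w_hh - 0.01 * dL_dwhh)         # about 0.024006 (slides round the gradient to -0.0005, giving 0.024005)
```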

Backpropagating through n time steps multiplies the gradient by w_hh (and a tanh derivative) over and over:

∂L/∂x = w_hh · … · w_hh · … · w_hh · … · w_hh = w_hh^n · C(w)

With W_hh = 0.024 the factor w_hh^n vanishes exponentially fast:

1.  0.024
2.  0.000576
3.  1.382e-05
4.  3.318e-07
5.  7.963e-09
6.  1.911e-10
7.  4.586e-12
8.  1.101e-13
9.  2.642e-15
10. 6.340e-17

After only a few steps the gradient is effectively zero: this is the vanishing-gradient problem that makes vanilla RNNs hard to train over long sequences.
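The same decay, computed directly:

```python
w_hh = 0.024
for n in range(1, 11):
    print(n, w_hh ** n)
# 1  0.024
# 2  0.000576
# 3  1.3824e-05
# ...
# 10 6.34e-17   -- after only ten steps the factor is essentially zero
```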

Long Short-Term Memory (LSTM)

The input x and the previous hidden state h_{t-1} (2n values in total, for a hidden size of n) are multiplied by a single weight matrix W of size 4n × 2n, producing the four gate vectors i, f, o, g (n values each):

(i, f, o, g) = (sigm, sigm, sigm, tanh) · W · (x, h_{t-1})
c_t = f ∙ c_{t-1} + i ∙ g
h_t = o ∙ tanh(c_t)

(the products in the last two lines are element-wise)

f is the forget gate (≈ 0/1): how much of the previous cell state c_{t-1} to keep.
i is the input gate (≈ 0/1): how much of the incoming candidate g to write into the cell.
o is the output gate: how much of tanh(c_t) to expose as the new hidden state h_t.

Compare with the vanilla RNN:

RNN:  h_t = tanh(W_hh h_{t-1} + W_hx x) = tanh(W · (x, h_{t-1}))
LSTM: c_t = f ∙ c_{t-1} + i ∙ g,   h_t = o ∙ tanh(c_t)
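A minimal NumPy sketch of one LSTM step following these equations (illustrative sizes and random weights; biases omitted, as on the slide, where the input and hidden sizes are equal so W is 4n × 2n):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W):
    """One LSTM step. x has size d, h_prev/c_prev size n, W has shape (4n, d + n)."""
    n = h_prev.shape[0]
    z = W @ np.concatenate([x, h_prev])   # 4n pre-activations
    i = sigmoid(z[0:n])                   # input gate
    f = sigmoid(z[n:2*n])                 # forget gate
    o = sigmoid(z[2*n:3*n])               # output gate
    g = np.tanh(z[3*n:4*n])               # candidate values
    c = f * c_prev + i * g                # new cell state (element-wise)
    h = o * np.tanh(c)                    # new hidden state
    return h, c

rng = np.random.default_rng(0)
d = n = 3                                 # slide's case: W is 4n x 2n
W = rng.normal(size=(4 * n, d + n))
h = c = np.zeros(n)
x = np.array([1., 0., 0.])
h, c = lstm_step(x, h, c, W)
print(h, c)
```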

Flow of gradient

RNN: going from step t+1 back through t to t-1, the gradient is multiplied by w_hh (and a tanh derivative) at every step, so over n steps it scales as w_hh^n · C(w) and vanishes.

LSTM: the cell state is updated additively, c_t = f ∙ c_{t-1} + i ∙ g. The gradient flows backwards through the “+” node and is scaled only by the forget gate f at each step, so when f is close to 1 it can pass through many time steps without vanishing.
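A toy comparison of the two scaling factors; the forget-gate values here are hypothetical, chosen near 1:

```python
import numpy as np

steps = 20
w_hh = 0.024                  # the RNN recurrent weight from the example above
f = np.full(steps, 0.95)      # hypothetical forget-gate activations near 1

rnn_factor  = w_hh ** steps   # gradient scale after 20 RNN steps  -> about 4e-33
lstm_factor = np.prod(f)      # gradient scale after 20 LSTM steps -> about 0.36
print(rnn_factor, lstm_factor)
```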

References

1. Long Short-Term Memory (Hochreiter & Schmidhuber, 1997), http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

2. Learning Long-Term Dependencies with Gradient Descent is Difficult (Bengio et al., 1994), http://www.dsi.unifi.it/~paolo/ps/tnn-94-gradient.pdf

3. Neural Networks and Deep Learning, chapter 5, Michael Nielsen, http://neuralnetworksanddeeplearning.com/chap5.html

4. Deep Learning, Ian Goodfellow et al., The MIT Press

5. Recurrent Neural Networks, LSTM, Andrej Karpathy, Stanford lectures, https://www.youtube.com/watch?v=iX5V1WpxxkY

Alex Kalinin alex@alexkalinin.com