Recurrent Networks and LSTM deep dive
Outline:
1. Example of a vanilla RNN
2. RNN forward pass
3. RNN backward pass
4. LSTM design
RNN Training problem

A feed-forward ("vanilla") network maps an input X straight to an output y in a single pass. A recurrent network (RNN) adds a hidden state h that is fed back into the network at every step, so there are three weight matrices: W_hx (input to hidden), W_hh (hidden to hidden), and W_hy (hidden to output).

Vanilla recurrent network:

(1) h_t = tanh(W_hh·h_{t−1} + W_hx·x + b_h)
(2) y = W_hy·h_t + b_y
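As a reference for the two equations above, here is a minimal NumPy sketch of one recurrent step; the function name and shapes are illustrative, not from the slides.

```python
import numpy as np

def rnn_step(x, h_prev, W_hh, W_hx, W_hy, b_h, b_y):
    """One step of a vanilla RNN: update the hidden state, then read out y."""
    h = np.tanh(W_hh @ h_prev + W_hx @ x + b_h)  # (1) h_t = tanh(W_hh h_{t-1} + W_hx x + b_h)
    y = W_hy @ h + b_y                           # (2) y = W_hy h_t + b_y
    return h, y
```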
Example: character-level language processing

Training sequence: "hello"
Vocabulary: [e, h, l, o]

Each character is encoded as a one-hot vector over the vocabulary:

"h" = [0, 1, 0, 0]
"e" = [1, 0, 0, 0]
"l" = [0, 0, 1, 0]
"o" = [0, 0, 0, 1]

The network reads one character (X), updates its hidden state h, and outputs scores Y, which are normalized into a probability distribution P over the next character. With a hidden state of size 1, the example uses the trained weights:

W_hx = [3.6, −4.8, 0.35, −0.26]
W_hy = [−12., −0.67, −0.85, 14.]
b_y = [−0.2, −2.9, 6.1, −3.4]
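A quick sketch of the encoding (vocabulary order as on the slide): because the input is one-hot, multiplying by W_hx just selects one of its columns.

```python
import numpy as np

vocab = ['e', 'h', 'l', 'o']
one_hot = {ch: np.eye(4)[i] for i, ch in enumerate(vocab)}  # "h" -> [0, 1, 0, 0]

W_hx = np.array([3.6, -4.8, 0.35, -0.26])
print(W_hx @ one_hot['h'])  # -4.8: the "h" column of W_hx
```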
Stepping through "hello" (at each step the network reads one character, updates h, and outputs p = softmax(y), the predicted probability of each vocabulary character coming next):

Step 1: input "h" = [0, 1, 0, 0], previous state h_0 = 0
  h = tanh(W_hh·h_0 + W_hx·x + b_h) = −0.99
  y = W_hy·h + b_y = [11., −2.2, 6.9, −17.]
  p = [0.99, 0, 0.01, 0] → predicts "e"

Step 2: input "e" = [1, 0, 0, 0], previous state h = −0.99
  h = −0.09
  y = [0.86, −2.8, 6.2, −4.6]
  p = [0, 0, 0.99, 0] → predicts "l"

Step 3: input "l" = [0, 0, 1, 0], previous state h = −0.09
  h = 0.38
  y = [−4.7, −3.2, 5.8, 1.9]
  p = [0, 0, 0.98, 0.02] → predicts "l"

Step 4: input "l" = [0, 0, 1, 0], previous state h = 0.38
  h = 0.98
  y = [−12., −3.6, 5.3, 10.]
  p = [0, 0, 0.01, 0.99] → predicts "o"

In summary, each step feeds the previous hidden state forward:

"h", h_0 = 0   → "e"
"e", h = −0.99 → "l"
"l", h = −0.09 → "l"
"l", h = 0.38  → "o"
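The slides give W_hy and b_y but not the recurrence parameters W_hh and b_h, so this sketch takes the hidden-state value recorded on each slide as given and just verifies the readout y = W_hy·h + b_y and p = softmax(y):

```python
import numpy as np

vocab = ['e', 'h', 'l', 'o']
W_hy = np.array([-12., -0.67, -0.85, 14.])
b_y  = np.array([-0.2, -2.9, 6.1, -3.4])

def softmax(y):
    e = np.exp(y - y.max())
    return e / e.sum()

# (input char, hidden state recorded on the slide after that input)
for ch, h in [('h', -0.99), ('e', -0.09), ('l', 0.38), ('l', 0.98)]:
    y = W_hy * h + b_y
    p = softmax(y)
    print(ch, '->', vocab[int(p.argmax())], np.round(y, 1), np.round(p, 2))
# Predicts e, l, l, o: the next characters of "hello".
```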
Trained on longer and longer sequences, the same network learns to reproduce them: "hello", "hello ben", "hello world"; then "it was", "it was the", "it was the best"; and eventually whole passages:

"It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness…", A Tale of Two Cities, Charles Dickens

Checkpoints during training on the Dickens text: 50,000 iterations; 300,000 (loss = 1.6066); 1,000,000 (loss = 1.8197); 2,000,000 (loss = 4.0844), where the model renders "it was the best of" as "it wes the best of". Training log:

…epoch 500000, loss: 6.447782290456328
…epoch 1000000, loss: 5.290576956983398
…epoch 1800000, loss: 4.267105168323299
epoch 1900000, loss: 4.175163586546514
epoch 2000000, loss: 4.0844739848413285
For training, the target sequence is simply the input sequence shifted by one character:

Input:  "i", "t", " ", "w", "a", "s", " ", "t"
Target: "t", " ", "w", "a", "s", " ", "t", "h"
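A minimal sketch of building such shifted (input, target) pairs from raw text; the variable names are illustrative:

```python
text = "it was the best of times"
inputs, targets = text[:-1], text[1:]  # target = input shifted by one character
print(list(zip(inputs, targets))[:8])
# [('i', 't'), ('t', ' '), (' ', 'w'), ('w', 'a'), ('a', 's'), ('s', ' '), (' ', 't'), ('t', 'h')]
```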
RNNs for Different Problems

- Vanilla neural network: one input → one output
- Image captioning: image → sequence of words
- Sentiment analysis: sequence of words → class
- Translation: sequence of words → sequence of words
Training example: a scalar RNN unrolled over three steps, with inputs x_0 = 1, x_1 = 1, x_2 = 2 and target output 3. The loss is a function of the weights:

L = f(W_hx, W_hh, W_hy)

Gradient descent nudges each weight against its gradient (learning rate 0.01; here W_hh = 0.024):

w_hx := w_hx − 0.01·∂L/∂w_hx
w_hh := w_hh − 0.01·∂L/∂w_hh
w_hy := w_hy − 0.01·∂L/∂w_hy
Training is hard with vanilla RNNs

The gradient collects one component per weight matrix:

∇L = [∂L/∂w_hx, ∂L/∂w_hh, ∂L/∂w_hy]

The forward pass runs through the unrolled network from inputs to loss; the backward pass propagates the gradient back through the same graph.
So what is ∂L/∂w_hh? Unrolling the three steps (x_0 = 1, x_1 = 1, x_2 = 2, target 3, biases omitted) turns the loss into one deeply nested expression:

L = (W_hy·tanh(W_hh·tanh(W_hh·tanh(W_hx·x_0) + W_hx·x_1) + W_hx·x_2) − 3)²

Compute gradient

Recursive application of the chain rule: for a composition f = f(g), g = g(h), h = h(k), and so on, with L = f(g(h(k(l(m(n(w))))))),

∂L/∂w = ∂f/∂g · ∂g/∂h · ∂h/∂k · ∂k/∂l · ∂l/∂m · ∂m/∂n · ∂n/∂w
Gradient by hand

Concretely, take W_hx = 0.078, W_hh = 0.024, W_hy = 0.051, inputs x_0 = 1, x_1 = 1, x_2 = 2, and target 3.

Forward Pass

Build the computational graph node by node (a_t denotes the pre-activation at step t):

a_0 = W_hx·x_0 = 0.078·1 = 0.078
h_0 = tanh(a_0) = 0.0778
W_hh·h_0 = 0.024·0.0778 = 0.00187
W_hx·x_1 = 0.078·1 = 0.078
a_1 = 0.00187 + 0.078 = 0.07987
h_1 = tanh(a_1) = 0.0797
W_hh·h_1 = 0.024·0.0797 = 0.0019
W_hx·x_2 = 0.078·2 = 0.156
a_2 = 0.0019 + 0.156 = 0.1579
h_2 = tanh(a_2) = 0.1566
y = W_hy·h_2 = 0.051·0.1566 = 0.0080
y − 3 = 0.0080 − 3 = −2.99
L = (−2.99)² = 8.95
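The whole forward pass fits in a few lines of NumPy and reproduces the numbers above:

```python
import numpy as np

W_hx, W_hh, W_hy = 0.078, 0.024, 0.051  # weights from the slides
xs, target = [1.0, 1.0, 2.0], 3.0       # x0, x1, x2 and the target output

h = 0.0
for x in xs:
    h = np.tanh(W_hh * h + W_hx * x)    # h0 = 0.0778, h1 = 0.0797, h2 = 0.1566

y = W_hy * h                            # 0.0080
L = (y - target) ** 2                   # (-2.99)^2 = 8.95
print(h, y, L)
```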
Now walk the chain rule backwards through this expression, one node at a time, to get ∂L/∂w_hh and ∂L/∂w_hx.
Backward Pass
Write g = y − 3, so L = g². At each node, multiply the local derivative by the gradient flowing in from above:

∂L/∂L = 1
∂L/∂g = ∂(g²)/∂g = 2g = 2·(−2.99) = −5.98
The subtraction node passes the gradient through unchanged (∂g/∂y = 1), so the gradient at y is −5.98.
Multiplication by W_hy: ∂y/∂h_2 = W_hy = 0.051, so the gradient at h_2 is −5.98·0.051 = −0.304. (The same node gives ∂y/∂W_hy = h_2 = 0.1566, so ∂L/∂W_hy = −5.98·0.1566 = −0.936.)
tanh node: ∂h_2/∂a_2 = 1 − h_2² = 1 − 0.1566² = 0.975, so the gradient at a_2 = W_hh·h_1 + W_hx·x_2 is −0.304·0.975 = −0.297.
The addition node copies −0.297 to both branches. On the recurrent branch, ∂(W_hh·h_1)/∂h_1 = W_hh = 0.024, so the gradient at h_1 is −0.297·0.024 = −0.0071.
tanh node: 1 − h_1² = 1 − 0.0797² = 0.993, so the gradient at a_1 = W_hh·h_0 + W_hx·x_1 is −0.0071·0.993 = −0.0071. On this path ∂a_1/∂w_hh = h_0 = 0.0778, giving ∂L/∂w_hh ≈ −0.0071·0.0778 ≈ −0.0005.
Multiplication by W_hh again: the gradient at h_0 is −0.0071·0.024 = −0.00017.
tanh node: 1 − h_0² = 1 − 0.0778² = 0.993, so the gradient at a_0 = W_hx·x_0 is −0.00017·0.993 ≈ −0.00017.
Finally ∂a_0/∂w_hx = x_0 = 1, so along this path ∂L/∂w_hx ≈ −0.00017.

With learning rate 0.01, the updates w_a := w_a − 0.01·∂L/∂w_a give:

w_hx := 0.078 − 0.01·(−0.00017) = 0.0780017
w_hh := 0.024 − 0.01·(−0.0005) = 0.024005
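The same chain, in code. Like the slides, it follows only the deepest path through the graph; a full BPTT implementation would also sum the w_hx and w_hh contributions from the other time steps into each gradient.

```python
import numpy as np

W_hx, W_hh, W_hy = 0.078, 0.024, 0.051
x0, x1, x2, target = 1.0, 1.0, 2.0, 3.0

# Forward pass, keeping intermediates for the backward pass.
h0 = np.tanh(W_hx * x0)                 # 0.0778
h1 = np.tanh(W_hh * h0 + W_hx * x1)     # 0.0797
h2 = np.tanh(W_hh * h1 + W_hx * x2)     # 0.1566
y = W_hy * h2                           # 0.0080

# Backward pass along the deepest path.
dy  = 2 * (y - target)                  # -5.98
dh2 = dy * W_hy                         # -0.304
da2 = dh2 * (1 - h2 ** 2)               # -0.297
dh1 = da2 * W_hh                        # -0.0071
da1 = dh1 * (1 - h1 ** 2)               # -0.0071
dW_hh = da1 * h0                        # -0.0005 (w_hh contribution on this path)
dh0 = da1 * W_hh                        # -0.00017
da0 = dh0 * (1 - h0 ** 2)               # -0.00017
dW_hx = da0 * x0                        # -0.00017 (w_hx contribution on this path)

print(W_hx - 0.01 * dW_hx)              # 0.0780017
print(W_hh - 0.01 * dW_hh)              # 0.024005
```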
The root cause: the gradient reaching early time steps is a product with one w_hh factor (and one tanh derivative) per step:

∂L/∂x = w_hh·…·w_hh·…·w_hh·…·w_hh = w_hh^n·C(w)

With w_hh = 0.024, the powers w_hh^n vanish almost immediately:

1. 0.024
2. 0.000576
3. 1.382e-05
4. 3.318e-07
5. 7.963e-09
6. 1.911e-10
7. 4.586e-12
8. 1.101e-13
9. 2.642e-15
10. 6.340e-17

Source: https://imgur.com/gallery/vaNahKE
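A two-line check; since |tanh′| ≤ 1, the full product of factors shrinks at least this fast:

```python
w_hh = 0.024
for n in range(1, 11):
    print(n, w_hh ** n)  # 0.024, 0.000576, 1.382e-05, ..., 6.340e-17
```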
Long Short-Term Memory (LSTM)

With hidden size n, the LSTM stacks the input x and the previous hidden state h_{t−1} into a vector of size 2n; a single weight matrix W of size 4n × 2n then produces the four gate pre-activations i, f, o, g (size n each):

(i, f, o, g) = (sigm, sigm, sigm, tanh)·W·(x, h_{t−1})
c_t = f ∘ c_{t−1} + i ∘ g
h_t = o ∘ tanh(c_t)

i, f, and o pass through a sigmoid (σ); g passes through tanh. The cell state c_t is carried from step t−1 to step t alongside the hidden state h_t.
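A sketch of one LSTM step with the stacked-gate layout from the slide (W of shape 4n × 2n). The bias term and the assumption that x also has size n are additions for the sake of a runnable example:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    """One LSTM step; W has shape (4n, 2n), b has shape (4n,)."""
    n = h_prev.shape[0]
    z = W @ np.concatenate([x, h_prev]) + b  # stacked pre-activations (i, f, o, g)
    i = sigmoid(z[:n])           # input gate
    f = sigmoid(z[n:2 * n])      # forget gate
    o = sigmoid(z[2 * n:3 * n])  # output gate
    g = np.tanh(z[3 * n:])       # candidate values
    c = f * c_prev + i * g       # c_t = f . c_{t-1} + i . g
    h = o * np.tanh(c)           # h_t = o . tanh(c_t)
    return h, c

# Example with hidden size n = 3:
n = 3
rng = np.random.default_rng(0)
h, c = lstm_step(rng.normal(size=n), np.zeros(n), np.zeros(n),
                 rng.normal(size=(4 * n, 2 * n)), np.zeros(4 * n))
```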
Side by side:

RNN:  h_t = tanh(W_hh·h_{t−1} + W_hx·x), i.e. h_t = tanh(W·(x, h_{t−1}))
LSTM: (i, f, o, g) = (sigm, sigm, sigm, tanh)·W·(x, h_{t−1})
      c_t = f ∘ c_{t−1} + i ∘ g
      h_t = o ∘ tanh(c_t)
Long Short-Term Memory (LSTM)

The forget gate f (≈0 or 1) scales the incoming cell state c_{t−1}: f = 0 erases the memory, f = 1 keeps it. The input gate i (≈0 or 1) scales the candidate g before it is added into the cell. The output gate o then scales tanh(c_t) to produce the new hidden state h_t:

c_t = f ∘ c_{t−1} + i ∘ g
h_t = o ∘ tanh(c_t)
Flow of gradient

In the vanilla RNN the gradient flowing from step t+1 back to t−1 is multiplied by w_hh (and a tanh derivative) at every step, so ∂L/∂x = w_hh·…·w_hh·…·w_hh = w_hh^n·C(w) and it vanishes. In the LSTM the cell state is updated additively, c_t = f ∘ c_{t−1} + i ∘ g, and addition passes the gradient through unchanged, so the gradient can flow along the chain of + nodes across many time steps.

Source: https://imgur.com/gallery/vaNahKE
Long Short-Term Memory (LSTM)

[Figure: annotated LSTM cell diagrams. Source: https://colah.github.io/posts/2015-08-Understanding-LSTMs/]
Reference
1. Long Short-Term Memory (Hochreiter & Schmidhuber, 1997), http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
2. Learning Long-Term Dependencies with Gradient Descent is Difficult (Bengio et al., 1994), http://www.dsi.unifi.it/~paolo/ps/tnn-94-gradient.pdf
3. Neural Networks and Deep Learning, ch. 5, http://neuralnetworksanddeeplearning.com/chap5.html
4. Deep Learning, Ian Goodfellow et al., The MIT Press
5. Recurrent Neural Networks, LSTM, Andrej Karpathy, Stanford Lectures, https://www.youtube.com/watch?v=iX5V1WpxxkY
Alex Kalinin [email protected]