Page 1: Recurrent Networks and LSTM deep dive

Recurrent Neural Networks

Alex Kalinin


Content

1. Example of Vanilla RNN
2. RNN Forward pass
3. RNN Backward pass
4. LSTM design

RNN Training problem

Feed-forward ("vanilla") network

[Figure: feed-forward network diagram]

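Vanilla recurrent network

[Figure: input $x$ feeds an RNN cell with hidden state $h$, which produces output $y$; weights $W_{hx}$ (input to hidden), $W_{hh}$ (hidden to hidden, the recurrent loop) and $W_{hy}$ (hidden to output)]

(1) $h_t = \tanh(W_{hh} h_{t-1} + W_{hx} x + b_h)$

(2) $y = W_{hy} h_t + b_y$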
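The two update rules can be sketched in NumPy as a single step function, assuming NumPy is installed. The sizes (4 inputs, 3 hidden units, 4 outputs) and the random weights are hypothetical, chosen only to show the shapes. The same weight matrices are reused at every time step; only $h$ carries information from one step to the next.

```python
import numpy as np

def rnn_step(x, h_prev, W_hx, W_hh, W_hy, b_h, b_y):
    """One time step of a vanilla RNN, following equations (1) and (2)."""
    h = np.tanh(W_hh @ h_prev + W_hx @ x + b_h)  # (1) new hidden state h_t
    y = W_hy @ h + b_y                           # (2) output scores y
    return h, y

# Hypothetical sizes and random weights, only to illustrate the shapes.
rng = np.random.default_rng(0)
input_size, hidden_size, output_size = 4, 3, 4
W_hx = rng.normal(size=(hidden_size, input_size))
W_hh = rng.normal(size=(hidden_size, hidden_size))
W_hy = rng.normal(size=(output_size, hidden_size))
b_h = np.zeros(hidden_size)
b_y = np.zeros(output_size)

h = np.zeros(hidden_size)            # h_0 = 0
for x in np.eye(input_size):         # four one-hot inputs, one per time step
    h, y = rnn_step(x, h, W_hx, W_hh, W_hy, b_h, b_y)  # same weights every step
print(h.shape, y.shape)
```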

Example: character-level language processing

Training sequence: "hello"

Vocabulary: [e, h, l, o]

Each character is fed to the RNN as a one-hot vector over the vocabulary:

"h" = [0 1 0 0]
"e" = [1 0 0 0]
"l" = [0 0 1 0]
"o" = [0 0 0 1]
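A minimal sketch of this one-hot encoding, using the slide's vocabulary order [e, h, l, o]; the helper name one_hot is hypothetical.

```python
import numpy as np

VOCAB = ["e", "h", "l", "o"]  # vocabulary order used on the slide
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(VOCAB)}

def one_hot(ch):
    """Return the one-hot vector for a single character of VOCAB."""
    v = np.zeros(len(VOCAB))
    v[CHAR_TO_INDEX[ch]] = 1.0
    return v

for ch in "hello":
    print(ch, one_hot(ch).astype(int))
```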

Parameters of the "hello" RNN (the hidden state $h$ is a single number):

$W_{hx} = [3.6,\ -4.8,\ 0.35,\ -0.26]$

$W_{hy} = [-12,\ -0.67,\ -0.85,\ 14]^\top$

$b_y = [-0.2,\ -2.9,\ 6.1,\ -3.4]^\top$

Forward pass on "hello"

Starting from $h_0 = 0$, each step reads one input character, updates the hidden state with $h_t = \tanh(W_{hh} h_{t-1} + W_{hx} x + b_h)$, computes the scores $y = W_{hy} h_t + b_y$, and turns them into probabilities $p$ over [e, h, l, o] with a softmax. The most likely character is the prediction for the next input.

Input "h" = [0 1 0 0]: $h = -0.99$, $y = [11, -2.2, 6.9, -17]$, $p = [0.99, 0, 0.01, 0]$ -> predicts "e"

Input "e" = [1 0 0 0]: $h = -0.09$, $y = [0.86, -2.8, 6.2, -4.6]$, $p = [0, 0, 0.99, 0]$ -> predicts "l"

Input "l" = [0 0 1 0]: $h = 0.38$, $y = [-4.7, -3.2, 5.8, 1.9]$, $p = [0, 0, 0.98, 0.02]$ -> predicts "l"

Input "l" = [0 0 1 0]: $h = 0.98$, $y = [-12, -3.6, 5.3, 10]$, $p = [0, 0, 0.01, 0.99]$ -> predicts "o"

Sequence: "h" "e" "l" "l" "o"

Summary (input, hidden state carried in from the previous step, predicted next character):

"h", $h_0 = 0$ -> "e"
"e", $-0.99$ -> "l"
"l", $-0.09$ -> "l"
"l", $0.38$ -> "o"
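The forward pass above can be reproduced in a few lines of NumPy. $W_{hx}$, $W_{hy}$ and $b_y$ are the slide's values. $W_{hh}$ and $b_h$ never appear on the slides, so the values used here (4.11 and 0.42) are an assumption, back-solved from the printed hidden states; with them, the recomputed hidden states agree with the slides to within about 0.01 and the predictions spell "ello".

```python
import numpy as np

VOCAB = ["e", "h", "l", "o"]

# Values from the slide; the hidden state h is a single number.
W_hx = np.array([3.6, -4.8, 0.35, -0.26])
W_hy = np.array([-12.0, -0.67, -0.85, 14.0])
b_y = np.array([-0.2, -2.9, 6.1, -3.4])

# ASSUMPTION: W_hh and b_h are not shown on the slides; these values were
# back-solved from the printed hidden states (-0.99, -0.09, 0.38, 0.98).
W_hh = 4.11
b_h = 0.42

def softmax(z):
    e = np.exp(z - z.max())  # shift by the max for numerical stability
    return e / e.sum()

h = 0.0                      # h_0 = 0
hidden_states = []
predicted = ""
for ch in "hell":
    x = np.eye(len(VOCAB))[VOCAB.index(ch)]  # one-hot input
    h = np.tanh(W_hh * h + W_hx @ x + b_h)   # h_t = tanh(W_hh h_{t-1} + W_hx x + b_h)
    y = W_hy * h + b_y                       # y = W_hy h_t + b_y
    p = softmax(y)
    hidden_states.append(float(h))
    predicted += VOCAB[int(np.argmax(p))]    # most likely next character
    print(ch, round(float(h), 2), np.round(p, 2))

print(predicted)  # ello
```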

Results on longer training sequences (training sequence -> text produced by the trained network):

"hello" -> "hello"
"hello ben" -> "hello ben"
"hello world" -> "hello world"

Training text: "It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness…", A Tale of Two Cities, Charles Dickens

"it was" -> "it was" (50,000 epochs)
"it was the" -> "it was the" (300,000 epochs, loss = 1.6066)
"it was the best" -> "it was the best" (1,000,000 epochs, loss = 1.8197)
"it was the best of" -> "it wes the best of" (2,000,000 epochs, loss = 4.0844)

Training log for "it was the best of":

…
epoch 500000, loss: 6.447782290456328
…
epoch 1000000, loss: 5.290576956983398
…
epoch 1800000, loss: 4.267105168323299
epoch 1900000, loss: 4.175163586546514
epoch 2000000, loss: 4.0844739848413285
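The slides do not show the training code. A minimal sketch of how such a character-level RNN could be trained, assuming a softmax output with cross-entropy loss, full backpropagation through time over the whole sequence, gradient clipping, and plain gradient descent. The function name train_char_rnn and every hyperparameter are hypothetical, and the sketch is not expected to reproduce the losses above.

```python
import numpy as np

def train_char_rnn(text, hidden_size=16, steps=2000, lr=0.1, seed=0):
    """Train a tiny vanilla RNN to predict the next character of `text`.

    Returns the loss (cross-entropy summed over the sequence) at every step.
    """
    vocab = sorted(set(text))
    index = {ch: i for i, ch in enumerate(vocab)}
    V = len(vocab)
    inputs = [index[ch] for ch in text[:-1]]   # every character but the last
    targets = [index[ch] for ch in text[1:]]   # the next character at each step

    rng = np.random.default_rng(seed)
    W_hx = rng.normal(scale=0.1, size=(hidden_size, V))
    W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
    W_hy = rng.normal(scale=0.1, size=(V, hidden_size))
    b_h = np.zeros(hidden_size)
    b_y = np.zeros(V)
    losses = []

    for _ in range(steps):
        # Forward pass over the whole sequence, starting from h = 0.
        hs = {-1: np.zeros(hidden_size)}
        ps = {}
        loss = 0.0
        for t, i in enumerate(inputs):
            x = np.eye(V)[i]
            hs[t] = np.tanh(W_hh @ hs[t - 1] + W_hx @ x + b_h)
            y = W_hy @ hs[t] + b_y
            e = np.exp(y - y.max())
            ps[t] = e / e.sum()                    # softmax probabilities
            loss -= np.log(ps[t][targets[t]])      # cross-entropy
        losses.append(float(loss))

        # Backward pass: backpropagation through time, last step first.
        dW_hx, dW_hh, dW_hy = np.zeros_like(W_hx), np.zeros_like(W_hh), np.zeros_like(W_hy)
        db_h, db_y = np.zeros_like(b_h), np.zeros_like(b_y)
        dh_next = np.zeros(hidden_size)
        for t in reversed(range(len(inputs))):
            dy = ps[t].copy()
            dy[targets[t]] -= 1.0                  # gradient of softmax + cross-entropy
            dW_hy += np.outer(dy, hs[t])
            db_y += dy
            dh = W_hy.T @ dy + dh_next             # from the output and from step t + 1
            da = (1.0 - hs[t] ** 2) * dh           # through tanh
            dW_hx += np.outer(da, np.eye(V)[inputs[t]])
            dW_hh += np.outer(da, hs[t - 1])
            db_h += da
            dh_next = W_hh.T @ da                  # passed back to step t - 1

        # Clip, then take a plain gradient-descent step (in place).
        for param, grad in ((W_hx, dW_hx), (W_hh, dW_hh), (W_hy, dW_hy), (b_h, db_h), (b_y, db_y)):
            np.clip(grad, -5.0, 5.0, out=grad)
            param -= lr * grad
    return losses

losses = train_char_rnn("hello")
print(f"loss at step 0: {losses[0]:.3f}, after training: {losses[-1]:.3f}")
```

For "hello", sorted(set(text)) happens to give the same vocabulary order as the slides, [e, h, l, o].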


Training pairs: at every step the target is the next character of the text, so the target sequence is the input sequence shifted by one position.

Input:  i t " " w a s " " t

Target: t " " w a s " " t h
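Building the input and target sequences from raw text is a one-character shift. A sketch, with make_pairs as a hypothetical helper name:

```python
def make_pairs(text):
    """Split text into (inputs, targets); each target is the next input character."""
    return list(text[:-1]), list(text[1:])

inputs, targets = make_pairs("it was th")
for inp, tgt in zip(inputs, targets):
    print(repr(inp), "->", repr(tgt))
```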

RNNs for Different Problems

Vanilla neural network

Image captioning: image -> sequence of words

Sentiment analysis: sequence of words -> class

Translation: sequence of words -> sequence of words

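Training is hard with vanilla RNNs

Toy example: a scalar RNN unrolled over three time steps, with inputs $x_0 = 1$, $x_1 = 1$, $x_2 = 2$, hidden states $h_0, h_1, h_2$, and target output 3. The initial weights are $W_{hx} = 0.078$, $W_{hh} = 0.024$ and $W_{hy} = 0.051$.

The loss depends on the three weights, $L = f(W_{hx}, W_{hh}, W_{hy})$. The forward pass computes $L$; the backward pass computes its gradient

$$\nabla L = \left[\frac{\partial L}{\partial w_{hx}},\ \frac{\partial L}{\partial w_{hh}},\ \frac{\partial L}{\partial w_{hy}}\right]$$

and gradient descent with learning rate 0.01 updates each weight:

$$w_{hx} := w_{hx} - 0.01 \cdot \frac{\partial L}{\partial w_{hx}}$$

$$w_{hh} := w_{hh} - 0.01 \cdot \frac{\partial L}{\partial w_{hh}}$$

$$w_{hy} := w_{hy} - 0.01 \cdot \frac{\partial L}{\partial w_{hy}}$$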

What are $L$ and $\partial L / \partial w_{hh}$?

Unrolling the network over the three steps gives the loss in closed form:

$$L = \Big(W_{hy} \tanh\big(W_{hh} \tanh\big(W_{hh} \tanh(W_{hx} x_0) + W_{hx} x_1\big) + W_{hx} x_2\big) - 3\Big)^2$$

Compute gradient

Recursive application of chain rule: if the loss is a composition $L = f(g(h(k(l(m(n(w)))))))$, then

$$\frac{\partial L}{\partial w} = \frac{\partial f}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial k} \cdot \frac{\partial k}{\partial l} \cdot \frac{\partial l}{\partial m} \cdot \frac{\partial m}{\partial n} \cdot \frac{\partial n}{\partial w}$$

Here $f, g, h, \dots$ name the nested functions; this $h$ is not the hidden state.
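Any chain-rule result can be checked numerically by evaluating the closed-form loss and differentiating it with central finite differences. A plain-Python sketch, assuming the toy values above: $x = (1, 1, 2)$, target 3, and initial weights $(W_{hx}, W_{hh}, W_{hy}) = (0.078, 0.024, 0.051)$. The function names loss and numerical_grad are hypothetical.

```python
import math

X = [1.0, 1.0, 2.0]  # inputs x_0, x_1, x_2 of the toy example
TARGET = 3.0

def loss(w_hx, w_hh, w_hy):
    """Closed-form loss of the three-step scalar RNN (no biases)."""
    h = math.tanh(w_hx * X[0])              # h_0
    h = math.tanh(w_hh * h + w_hx * X[1])   # h_1
    h = math.tanh(w_hh * h + w_hx * X[2])   # h_2
    return (w_hy * h - TARGET) ** 2

def numerical_grad(w, eps=1e-6):
    """Central finite differences, one weight at a time."""
    grads = []
    for i in range(len(w)):
        up, down = list(w), list(w)
        up[i] += eps
        down[i] -= eps
        grads.append((loss(*up) - loss(*down)) / (2 * eps))
    return grads

w = [0.078, 0.024, 0.051]  # W_hx, W_hh, W_hy
print(round(loss(*w), 2))   # 8.95
print(numerical_grad(w))
```

The finite-difference gradient comes out at roughly (-0.603, -0.0243, -0.937). It is the full derivative, including every time step, so a hand-derived gradient should match it.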


Gradient by hand

Forward Pass

Evaluated node by node with $W_{hx} = 0.078$, $W_{hh} = 0.024$, $W_{hy} = 0.051$ and inputs $x_0 = 1$, $x_1 = 1$, $x_2 = 2$:

Step 0: $W_{hx} x_0 = 0.078 \cdot 1 = 0.078$, so $h_0 = \tanh(0.078) = 0.0778$

Step 1: $W_{hh} h_0 = 0.024 \cdot 0.0778 = 0.00187$ and $W_{hx} x_1 = 0.078$; their sum is $0.07987$, so $h_1 = \tanh(0.07987) = 0.07970$

Step 2: $W_{hh} h_1 = 0.0019$ and $W_{hx} x_2 = 0.078 \cdot 2 = 0.156$; their sum is $0.1579$, so $h_2 = \tanh(0.1579) = 0.1566$

Output: $y = W_{hy} h_2 = 0.051 \cdot 0.1566 = 0.0080$

Loss: $y - 3 = -2.99$, so $L = (-2.99)^2 = 8.95$
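The same forward pass in plain Python, keeping every intermediate node so it can be compared with the trace above. The comments give each value rounded as on the slides.

```python
import math

w_hx, w_hh, w_hy = 0.078, 0.024, 0.051
x = [1.0, 1.0, 2.0]
target = 3.0

h0 = math.tanh(w_hx * x[0])       # 0.0778
a1 = w_hh * h0 + w_hx * x[1]      # 0.07987
h1 = math.tanh(a1)                # 0.0797
a2 = w_hh * h1 + w_hx * x[2]      # 0.1579
h2 = math.tanh(a2)                # 0.1566
y = w_hy * h2                     # 0.0080
g = y - target                    # -2.99
L = g ** 2                        # 8.95
print(h0, h1, h2, y, L)
```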


Backward Pass

Start from $\partial L / \partial L = 1$ and move from the loss back towards the inputs, multiplying by the local derivative of each node. Write $g = y - 3$ and $L = f(g) = g^2$:

$\partial f / \partial g = 2g = 2(-2.99) = -5.98$

$\partial g / \partial y = 1$, so the gradient at $y$ is $-5.98$

At $y = W_{hy} h_2$: $\partial y / \partial h_2 = W_{hy} = 0.051$, giving $-5.98 \cdot 0.051 = -0.304$; and $\partial y / \partial W_{hy} = h_2 = 0.1566$, giving $\partial L / \partial W_{hy} = -5.98 \cdot 0.1566 = -0.936$

At $h_2 = \tanh(0.1579)$: $1 - h_2^2 = 1 - 0.1566^2 = 0.975$, giving $-0.304 \cdot 0.975 = -0.297$

The sum $W_{hh} h_1 + W_{hx} x_2$ passes $-0.297$ unchanged to both of its inputs

At $W_{hh} h_1$: $\partial / \partial h_1 = W_{hh} = 0.024$, giving $-0.297 \cdot 0.024 = -0.0071$

At $h_1 = \tanh(0.07987)$: $1 - h_1^2 = 1 - 0.0797^2 = 0.994$, giving $-0.0071$

The sum $W_{hh} h_0 + W_{hx} x_1$ passes $-0.0071$ to both of its inputs

At $W_{hh} h_0$: $\partial / \partial W_{hh} = h_0 = 0.0778$, giving $-0.0071 \cdot 0.0778 = -0.0005$; and $\partial / \partial h_0 = W_{hh} = 0.024$, giving $-0.0071 \cdot 0.024 = -0.00017$

At $h_0 = \tanh(0.078)$: $1 - h_0^2 = 1 - 0.0778^2 = 0.994$, giving $-0.00017$

At $W_{hx} x_0$: $\partial / \partial W_{hx} = x_0 = 1$, giving $-0.00017$

Along this path back to the first time step the gradient shrinks at every stage: $-5.98 \to -0.297 \to -0.0071 \to -0.00017$.

Gradient-descent update $w_a := w_a - 0.01 \cdot \partial L / \partial w_a$, using the contributions that reach the first time step:

$w_{hx} := 0.078 - 0.01 \cdot (-0.00017) = 0.0780017$

$w_{hh} := 0.024 - 0.01 \cdot (-0.0005) = 0.024005$

The change coming from the first step is negligible. Because $W_{hx}$ and $W_{hh}$ are shared across time steps, their full gradients also add the contributions from later steps (for example, $W_{hx} x_2$ contributes $-0.297 \cdot 2 = -0.594$ to $\partial L / \partial W_{hx}$), so the update is dominated by the most recent inputs and the network barely learns how $x_0$ affects the output.
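A plain-Python sketch of the backward pass above, computed at full precision (the slides round intermediate values, so the last digit can differ slightly). It also forms the full gradients of the shared weights by summing one contribution per time step, which is how backpropagation through time treats weights reused at every step.

```python
import math

w_hx, w_hh, w_hy = 0.078, 0.024, 0.051
x = [1.0, 1.0, 2.0]
target = 3.0

# Forward pass, keeping the values the backward pass needs.
h0 = math.tanh(w_hx * x[0])
h1 = math.tanh(w_hh * h0 + w_hx * x[1])
h2 = math.tanh(w_hh * h1 + w_hx * x[2])
y = w_hy * h2
g = y - target

# Backward pass: multiply by each node's local derivative.
dg = 2 * g                   # d(g^2)/dg                  ~ -5.98
dw_hy = dg * h2              # dy/dw_hy = h2              ~ -0.937
dh2 = dg * w_hy              # dy/dh2 = w_hy              ~ -0.305
da2 = dh2 * (1 - h2 ** 2)    # tanh'(a) = 1 - tanh(a)^2   ~ -0.298
dh1 = da2 * w_hh             #                            ~ -0.0071
da1 = dh1 * (1 - h1 ** 2)    #                            ~ -0.0071
dh0 = da1 * w_hh             #                            ~ -0.00017
da0 = dh0 * (1 - h0 ** 2)    #                            ~ -0.00017

# Shared weights: sum one contribution per time step.
dw_hx = da2 * x[2] + da1 * x[1] + da0 * x[0]   # ~ -0.603, almost all from the last step
dw_hh = da2 * h1 + da1 * h0                    # ~ -0.0243

lr = 0.01
new_weights = [w - lr * d for w, d in ((w_hx, dw_hx), (w_hh, dw_hh), (w_hy, dw_hy))]
print("contribution of x_0 to dL/dw_hx:", da0 * x[0])
print("full gradient:", dw_hx, dw_hh, dw_hy)
print("updated weights:", new_weights)
```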

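Every step back in time multiplies the gradient by $w_{hh}$ once more, so after $n$ steps

$$\frac{\partial L}{\partial x} = w_{hh} \cdots w_{hh} \cdots w_{hh} \cdots w_{hh} = w_{hh}^n \cdot C(w)$$

where $C(w)$ collects the remaining factors, such as the tanh derivatives. With $W_{hh} = 0.024$, the factor $w_{hh}^n$ vanishes quickly:

1. 0.024
2. 0.000576
3. 1.382e-05
4. 3.318e-07
5. 7.963e-09
6. 1.911e-10
7. 4.586e-12
8. 1.101e-13
9. 2.642e-15
10. 6.340e-17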
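The list of powers can be regenerated directly:

```python
w_hh = 0.024
factors = [w_hh ** n for n in range(1, 11)]  # w_hh^n for n = 1..10
for n, factor in enumerate(factors, start=1):
    print(f"{n:2d}. {factor:.3e}")
```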

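Long Short-Term Memory (LSTM)

With $n$ hidden units and an input $x$ of size $n$, one weight matrix $W$ of shape $4n \times 2n$ maps the stacked vector $(x, h_{t-1})$ to the input gate $i$, forget gate $f$, output gate $o$ and candidate $g$, each of size $n$:

$$\begin{pmatrix} i \\ f \\ o \\ g \end{pmatrix} = \begin{pmatrix} \mathrm{sigm} \\ \mathrm{sigm} \\ \mathrm{sigm} \\ \tanh \end{pmatrix} W \begin{pmatrix} x \\ h_{t-1} \end{pmatrix}$$

$$c_t = f \odot c_{t-1} + i \odot g$$

$$h_t = o \odot \tanh(c_t)$$

Here $\odot$ is element-wise multiplication and $c_t$ is the cell state passed from step $t-1$ to step $t$. For comparison, a vanilla RNN computes

$$h_t = \tanh\left(W \begin{pmatrix} x \\ h_{t-1} \end{pmatrix}\right)$$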
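A NumPy sketch of a single LSTM step that follows these equations: one matrix $W$ of shape $4n \times 2n$, gates stacked in the order i, f, o, g, and no bias term, as on the slide. The helper names and the random weights and inputs are illustrative only.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W):
    """One LSTM step. W has shape (4n, 2n); x and h_prev both have size n."""
    n = h_prev.shape[0]
    z = W @ np.concatenate([x, h_prev])   # 4n pre-activations
    i = sigmoid(z[0:n])                   # input gate
    f = sigmoid(z[n:2 * n])               # forget gate
    o = sigmoid(z[2 * n:3 * n])           # output gate
    g = np.tanh(z[3 * n:4 * n])           # candidate values
    c = f * c_prev + i * g                # c_t = f ⊙ c_{t-1} + i ⊙ g
    h = o * np.tanh(c)                    # h_t = o ⊙ tanh(c_t)
    return h, c

rng = np.random.default_rng(0)
n = 3
W = rng.normal(size=(4 * n, 2 * n))
h, c = np.zeros(n), np.zeros(n)
for _ in range(5):                        # five time steps with random inputs
    h, c = lstm_step(rng.normal(size=n), h, c, W)
print(h, c)
```

With $W = 0$ every gate equals 0.5 and $g = 0$, so the cell state is simply halved at each step; that makes a convenient sanity check.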

RNN: $h_t = \tanh(W_{hh} h_{t-1} + W_{hx} x)$

LSTM: $c_t = f \odot c_{t-1} + i \odot g$, where the forget gate $f$ and the input gate $i$ are sigmoid outputs between 0 and 1. A forget gate near 1 keeps the previous cell state and near 0 erases it; an input gate near 1 adds the new candidate $g$ and near 0 ignores it.

Long Short-Term Memory (LSTM)

[Figure: LSTM cell. The incoming cell state $c_{t-1}$ is multiplied (X) by the forget gate $f$, the product of the input gate $i$ and the candidate $g$ is added (+) to give $c_t$, and $\tanh(c_t)$ is multiplied (X) by the output gate $o$ to give $h_t$.]

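Flow of gradient

RNN: going back from step $t+1$ to $t$ to $t-1$, the gradient is multiplied by $w_{hh}$ at every step, so $\partial L / \partial x = w_{hh}^n \cdot C(w)$.

LSTM: along the cell state, the gradient passes through an addition (+) and an element-wise multiplication by the forget gate $f$ at each step, so it is not repeatedly multiplied by the same recurrent weight.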
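A rough numerical illustration of the difference, under simplifying assumptions: only the direct path is considered (a factor $w_{hh}$ per step for the RNN, a factor $f$ per step along the LSTM cell state), the tanh derivatives and the dependence of the gates on $h$ are ignored, and the forget-gate value 0.95 is hypothetical.

```python
import numpy as np

steps = 10
w_hh = 0.024
rnn_factor = w_hh ** steps                  # RNN: one factor of w_hh per step

# HYPOTHETICAL forget-gate values close to 1 ("keep the cell state").
forget_gates = np.full(steps, 0.95)
lstm_factor = float(np.prod(forget_gates))  # LSTM cell path: one factor f per step

print(f"RNN:  {rnn_factor:.3e}")
print(f"LSTM: {lstm_factor:.3f}")
```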
