Long Short-Term Memory

From Psyc 40 Wiki

By Alphonso Bradham

Long Short-Term Memory (LSTM) refers to a type of recurrent neural network architecture useful for performing classification, regression, encoding, and decoding tasks on long sequence or time-series data. LSTMs were developed to counter the vanishing gradient problem [1], and the key feature of an LSTM network is the inclusion of a "cell state" vector that allows it to keep track of long-range relationships in data that other models would "forget".

Background on Sequence Data and RNNs

An example of a recurrent neural network architecture. As shown, the output of a given layer depends on both the data present at time t and the output of the network on the data at time t-1.

For data where sequence order is important, traditional feed-forward neural networks [2] often struggle to encode the temporal relationships between inputs. While suitable for simple regression and classification [3] tasks on independent samples of data, the plain matrix-vector architecture of a feed-forward network retains no capacity for "memory" and cannot accurately handle sequence data, where the output of the network on data point t depends on the value of a previously examined data point t-1.

It is in these situations (where a sequential memory is desired) that Recurrent Neural Networks (RNNs) are employed. Put simply, recurrent neural networks differ from traditional feed-forward neural networks in that neurons have connections that point "backwards" to reference earlier observations.[1] Operationally, this translates to RNNs having two inputs at each layer: one that references only the input data at the current time "t", in the traditional manner of feed-forward neural networks, and another that references the network's output at the previous time step "t-1". In this way, the input representing the previous time step "t-1" serves as the network's "memory" (since the previous output itself depended on the outputs of the time steps before it, and so on back to the beginning of the sequence). It is this capacity for memory (encoded in the RNN's performance rule) that makes RNNs well suited for handling sequence data tasks.
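As a rough sketch of this performance rule, the snippet below shows a single recurrent step in NumPy. The dimensions, random initialization, and tanh nonlinearity are illustrative assumptions, not a specific published architecture; the point is simply that each output depends on both the current input and the previous output.

 import numpy as np
 
 # Minimal sketch of one recurrent layer (illustrative assumptions only).
 rng = np.random.default_rng(0)
 input_size, hidden_size = 4, 8
 
 W_x = rng.normal(scale=0.1, size=(hidden_size, input_size))   # input-to-hidden weights
 W_h = rng.normal(scale=0.1, size=(hidden_size, hidden_size))  # hidden-to-hidden ("backwards") weights
 b = np.zeros(hidden_size)
 
 def rnn_step(x_t, h_prev):
     """One time step: the output depends on the data at time t AND the output at t-1."""
     return np.tanh(W_x @ x_t + W_h @ h_prev + b)
 
 # Running the same step over a sequence threads the "memory" forward.
 h = np.zeros(hidden_size)
 sequence = rng.normal(size=(5, input_size))  # 5 time steps of toy data
 for x_t in sequence:
     h = rnn_step(x_t, h)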


Vanishing Gradient Problem

As mentioned above, although RNNs' performance rule endows them with a capacity for memory, their learning is still defined by the back-propagation learning rule present in feed-forward neural networks. One of the side effects of the back-propagation algorithm is its tendency to "diminish" the amount that early weights are tuned as the network grows in depth. Because each weight update is proportional to the gradient of the error function, and because that gradient is computed by the chain rule as a product of many per-layer derivatives that are frequently smaller than one, the gradient shrinks roughly exponentially as it is propagated backwards through the network. Early layers of a network (which are updated last in back-propagation) therefore receive incredibly tiny adjustments in weight values. This phenomenon is known as the Vanishing Gradient problem [2].

Put more simply: the deeper the network, the smaller the changes its earlier layers are able to undergo. This becomes a particularly troublesome problem in the case of RNNs since, once unrolled through time, they are effectively very deep networks, with one layer per step of the sequence. Especially in cases where long-range contextualization and memory are desired, the vanishing gradient problem can severely hinder the ability of RNNs to accurately model and interpret sequence data. Avoiding the penalty of vanishing gradients in RNNs is the primary motivation underpinning LSTM models.
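The short numerical sketch below illustrates the effect with made-up values: back-propagating through T unrolled steps multiplies T per-step chain-rule factors together, and if each factor is less than one (as with a sigmoid derivative, which is at most 0.25, scaled by a modest recurrent weight), the product collapses towards zero as the sequence grows.

 import numpy as np
 
 def sigmoid(z):
     return 1.0 / (1.0 + np.exp(-z))
 
 # Hypothetical values chosen only to show the trend.
 w_rec = 0.9          # a recurrent weight
 z = 0.5              # a pre-activation value
 per_step_factor = w_rec * sigmoid(z) * (1.0 - sigmoid(z))  # one chain-rule factor
 
 for T in (1, 10, 50, 100):
     # Scale of the gradient that reaches the earliest time step after T steps.
     print(T, per_step_factor ** T)
 # By T=100 the factor is vanishingly small -- effectively zero.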


LSTM Architecture

A figure depicting the three "gated" behaviors of the cell-state mechanism of a Long Short-Term Memory (LSTM) neural network architecture.

To combat the vanishing gradient problem, LSTM models implement a more sophisticated "gated" system for controlling how memory is passed down through the network. This mechanism, known as the "cell state", constitutes the LSTM's concept of "memory", and it can only be modified or accessed based on the outputs of three "gates" that are evaluated in the performance rule of the network.[3] These gates are: the Input Gate, the Output Gate, and the Forget Gate.

As the network moves through a given sequence and encounters new data points, it has three options for what to do with previous observations. It can:

1. Forget: All previous data points in memory are reset and the memory is returned to an empty state (the cell state "forgets" everything before the current data point).

2. Update: The memory cell is updated to reflect the value of the current data point in combination with all previous data points (the current data point is "remembered" in the cell state).

3. Output: The cell state is used to modify the output of the hidden state vector at the current time step (the current set of "memories" housed in the cell state is used to influence the output of the network at the current time step).


Each of these three options is "gated", meaning that the network will only undertake the operation if the given data causes the function controlling that gate to surpass some threshold. By way of analogy, consider the sentence: "I like most pie, but I don't like apple pie." In this case, an LSTM would choose to forget everything that came before the word "but", since the operative part of the sentence concerns the speaker not liking apple pie; everything before the "but" is superfluous and only clutters memory, so it is forgotten. By selectively updating, forgetting, and utilizing memory, LSTMs are able to avoid the vanishing gradient problem: the cell state can carry information across many time steps largely unchanged, and memory operations are only undertaken when certain conditions are met (rather than at every time step). This selective relationship with memory means that LSTMs perform incredibly well on datasets where long-range contextual understanding and a conditional memory are needed to draw useful insights from data.
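A minimal NumPy sketch of one gated LSTM step is shown below. It follows the commonly used formulation of the forget, input, and output gates described above; the weight shapes, initialization, and helper names are placeholder assumptions for illustration, not the implementation of any particular library.

 import numpy as np
 
 def sigmoid(z):
     return 1.0 / (1.0 + np.exp(-z))
 
 rng = np.random.default_rng(0)
 input_size, hidden_size = 4, 8
 
 def make_gate():
     """Weights for one gate, acting on the concatenated [h_prev, x_t] vector."""
     return (rng.normal(scale=0.1, size=(hidden_size, hidden_size + input_size)),
             np.zeros(hidden_size))
 
 (W_f, b_f), (W_i, b_i), (W_o, b_o), (W_c, b_c) = (make_gate() for _ in range(4))
 
 def lstm_step(x_t, h_prev, c_prev):
     v = np.concatenate([h_prev, x_t])
     f = sigmoid(W_f @ v + b_f)          # forget gate: how much of the old cell state to keep
     i = sigmoid(W_i @ v + b_i)          # input gate: how much of the candidate to write
     o = sigmoid(W_o @ v + b_o)          # output gate: how much of the cell state to expose
     c_tilde = np.tanh(W_c @ v + b_c)    # candidate values for the cell state
     c = f * c_prev + i * c_tilde        # update the cell state ("memory")
     h = o * np.tanh(c)                  # hidden output influenced by the gated cell state
     return h, c

Because the cell state c is updated by elementwise scaling and addition rather than by repeated squashing, information (and gradient) can flow across many time steps when the forget gate stays near one.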

Applications of LSTMs

Because of their robustness on long sequences and their ability to efficiently capture contextual dependencies, LSTM networks are often used in Language Translation [4][5][6], Speech Recognition [7][8][9], and even in learning policies for playing complex video games [10]. These capabilities make LSTMs a powerful tool in the world of neural network architectures, and the future of LSTMs seems promising.

References