An attempt at parallelizing recurrent nets during training. This is just an idea, which I have not had the time to test yet. If there is work already done on something similar, please let me know.

Setup – let’s say you have a set of sequential data:

\[D = \{(X_i, Y_i) \}_{i=0}^{N} \\ X_i = (x_0, x_1, ..., x_j) \\ Y_i = (y_0, y_1, ..., y_j)\]

and you want to train a neural network to uncover the structure in the sequences – predicting the output from the input, for example. There are different ways to model that, depending on the nature of the sequences, but one of the most universal neural network architectures one could use is some sort of recurrent neural network – an RNN.

RNNs allow you to model data with arbitrary sequence length, which is what makes them special. They are characterized by a state vector that changes as every input is processed and stores anything the model deems important.

\[f(s_t, x_t) = [s_{t+1}, y_t]\]

You give the network state ($s_t$) and input ($x_t$) and it spits out the next state ($s_{t+1}$) and output ($y_t$). This looks really natural and generic. The task of the network during training is to learn to remember the important parts of the input and to use them when they are needed.
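As a concrete sketch, such a cell could look like the following in PyTorch. The dimensions, the tanh nonlinearity, and the linear readout are arbitrary choices of mine, not part of the idea:

```python
import torch
import torch.nn as nn

class RNNCell(nn.Module):
    """A minimal recurrent cell: f(s_t, x_t) = [s_{t+1}, y_t]."""

    def __init__(self, state_dim, input_dim, output_dim):
        super().__init__()
        self.state_update = nn.Linear(state_dim + input_dim, state_dim)
        self.readout = nn.Linear(state_dim, output_dim)

    def forward(self, state, x):
        # The next state is computed from the current state and the input...
        next_state = torch.tanh(self.state_update(torch.cat([state, x], dim=-1)))
        # ...and the output is read off the new state.
        y = self.readout(next_state)
        return next_state, y
```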


The problem with this setup is that the training process is sequential by nature. If you want to compute the output for time step $t$, you need to know the state of the network at time step $t-1$. Applied recursively, this shows that to do one optimization step, one has to process the whole sequence forward and then backward, sequentially.

> The biggest lesson that can be read from 70 years of AI research is that general methods that leverage computation are ultimately the most effective, and by a large margin.
>
> – Rich Sutton, *The Bitter Lesson*

I’ve been thinking a lot about this lately – about how we can scale it up – and I came up with an interesting (in my biased opinion) way of training every time step in parallel. That is to say, every time step can be processed forward and backward independently.

How can one do that if the input at every time step depends on the output of the previous time step? Well, let’s say we have that output vector already computed and stored somewhere. In a way, the computed state becomes a trainable embedding parameter.

\[(x_i, y_i) - \text{an arbitrary transition} \\ f(e_i, x_i) = [e_{i+1}, y_i]\]

This can easily be done by associating every time transition, of every sequence, with two unique ids. During training, the input and output states corresponding to these ids are looked up as vectors in an embedding layer – hence a representation of the transition. The gradient signal from backprop can then update these vectors, refining them as needed.
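Here is a rough sketch of what this could look like, reusing the `RNNCell` from above. The post leaves open how exactly the predicted next state is tied to the stored embedding of the output id; I am assuming a simple consistency loss between the two, and all names and hyperparameters here are mine:

```python
num_states = 10_000  # one id per state boundary, across all sequences
states = nn.Embedding(num_states, 64)  # trainable state vectors
cell = RNNCell(state_dim=64, input_dim=32, output_dim=32)
optimizer = torch.optim.Adam(
    list(states.parameters()) + list(cell.parameters())
)

def train_step(in_ids, out_ids, x, y_target):
    """One step over a batch of independent transitions.

    Transitions from anywhere in any sequence can be mixed in one
    batch, so the whole step is trivially parallel.
    """
    s_in = states(in_ids)  # look up the input states by id
    s_pred, y_pred = cell(s_in, x)
    # Output loss, plus a consistency loss pulling the predicted next
    # state and the stored embedding of the output id towards each other.
    loss = nn.functional.mse_loss(y_pred, y_target) \
         + nn.functional.mse_loss(s_pred, states(out_ids))
    optimizer.zero_grad()
    loss.backward()  # the gradients also refine the state embeddings
    optimizer.step()
    return loss.item()
```

Note that the consistency loss flows gradients into both the input and the output embeddings, which is the mechanism by which the stored states get refined.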

The training procedure becomes more memory heavy, but also fully parallel – the trade-off dynamic programming algorithms are known for: buying speed at the price of memory.

One problem I see with this is nonstationary inputs and outputs. Every optimization step changes the inputs and the outputs – the state vectors – so the optimization procedure is trying to hit a moving target. This can make training unstable, as seen in DQN, but there are ways to help with this problem.

The thing that comes to mind is what the DQN optimization procedure does – update the target network less frequently. This would mean freezing the embeddings for a few optimization steps, making the inputs and the outputs stationary for a while, accumulating the gradients for the embedding layer somewhere else, and then updating. The hope is that this would stabilize the procedure and make the whole process converge faster.
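Continuing the sketch above, one way to implement this is to give the state table its own optimizer and only step it periodically, letting the gradients accumulate in between. The split into two optimizers and the sync period `K` are my assumptions:

```python
cell_opt = torch.optim.Adam(cell.parameters())
state_opt = torch.optim.Adam(states.parameters())
K = 100  # how many steps the state embeddings stay frozen

def train_step_frozen(step, in_ids, out_ids, x, y_target):
    s_pred, y_pred = cell(states(in_ids), x)
    loss = nn.functional.mse_loss(y_pred, y_target) \
         + nn.functional.mse_loss(s_pred, states(out_ids))
    cell_opt.zero_grad()     # only the cell's grads are reset each step
    loss.backward()          # embedding grads keep accumulating in .grad
    cell_opt.step()
    if (step + 1) % K == 0:  # the states were stationary for K steps;
        state_opt.step()     # now apply the accumulated gradients at once
        state_opt.zero_grad()
    return loss.item()
```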

And now for the main problem: it takes N optimization steps before the input at time step $t$ can get the chance to influence the output at time step $t+N$. At the end of the day, we arrive at a somewhat similar problem – having to do N sequential steps to have a chance of passing signal across N time steps. It seems that doing things sequentially is part of the nature of modeling sequences, which I guess is logical and expected.

In conclusion, I think the described method is quite simple implementation-wise, and it may be something interesting to try out some day.