cover_simple_full

LSTM MDN generated handwriting sample and probability density of next point.

cover2a

cover2a1

cover2b

Longer generated handwriting samples.

This implementation is available on GitHub.

In our last post we discussed Mixture Density Networks and how they can be a useful tool for modelling data with several underlying states. Rather than trying to predict the expected value of a data point, an MDN lets us predict the entire probability density function of the data, which can be a lot more useful in many applications, and is also a very useful tool for generative tasks.

In this post I will discuss combining MDNs with LSTMs to generate artificial handwriting examples. We will try to implement part of Alex Graves’ work on generating sequential data with RNNs, after studying his paper and playing around with this awesome demo. For a good introduction to LSTMs and recurrent networks, I recommend Colah’s Understanding LSTM Networks and Karpathy’s The Unreasonable Effectiveness of Recurrent Neural Networks, where he discusses using RNNs to generate sequences of text data, such as Shakespeare, LaTeX documents, and even fake Linux kernel C code! After understanding how that approach works, and a bit about MDNs, you should be able to understand how this algorithm works.

I have implemented this demo in Python using TensorFlow, and I relied on this char-rnn-tensorflow implementation by sherjilozair for character-level text prediction. His example taught me a great deal about getting LSTMs working in TensorFlow.

Training Data

In order to get our neural network to write anything, it must first train on a relatively large set of handwriting examples. We will use the same data that Graves used in his paper, the IAM Handwriting Database. You need to ask them for permission to download it, so I couldn’t put the data on GitHub, and you will need to unzip the file lineStrokes-all.tar.gz into the data subdirectory yourself if you want to train the network.

The IAM database contains around 13,000 lines of handwriting recorded as digitised pen-stroke data. The data is stored in XML format as a set of strokes, and each stroke is a set of connected points drawn by the pen without lifting it from the paper. Below are a few examples of what the data look like. I have written some code to extract all of this data and draw it interactively in an IPython session:

    %run -i utils.py
    data_loader = DataLoader()
    for i in range(5):
      draw_strokes(random.choice(data_loader.raw_data))

iam_samples2

Some of the data is noisy and contains mistakes, including cases where the writer manually crossed out text after making an error. In his paper, Graves applied some filtering to detect bad examples, but in this demo I take all of the data in. The only preprocessing I did was to scale the dimensions of the data so that it is more compatible with the outputs of a neural net, and to cap the magnitude of the gap distance from one stroke to the next.
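
As a rough sketch of that preprocessing (the scale factor and cap below are illustrative values I made up, not the ones used in utils.py), the idea is simply to clip unusually large jumps between points and divide the offsets by a constant:

    import numpy as np

    def preprocess_offsets(offsets, scale=20.0, max_gap=500.0):
        # offsets has shape (T, 3), each row being (dx, dy, end-of-stroke).
        offsets = offsets.astype(np.float32).copy()
        # cap the magnitude of the gap from one stroke to the next
        offsets[:, 0:2] = np.clip(offsets[:, 0:2], -max_gap, max_gap)
        # scale the offsets down so they are friendlier to a neural net
        offsets[:, 0:2] /= scale
        return offsets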

Each training example can be interpreted as a bunch of points that comprise individual strokes:

    sample = random.choice(data_loader.raw_data)
    draw_strokes_random_color(sample, per_stroke_mode = False)
    draw_strokes_random_color(sample)

point_color

stroke_color

As we can see from the plots above, where we randomise the colours of each point or stroke, each example is made up of sets of connected points, and each set of connected points forms a stroke. We will model the data as a sequence of vectors containing the step sizes in the x and y directions to the next point, plus an end-of-stroke value that is either 0 or 1, denoting whether the next point is still part of the current stroke or whether we need to lift the pen up and start a new stroke.
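
As a small illustration of this representation (only a sketch; the real conversion lives in utils.py and the function name here is made up), turning absolute pen coordinates into these step vectors might look like:

    import numpy as np

    def points_to_offsets(points):
        # points has shape (T, 3): absolute (x, y) plus a flag that is 1 on the
        # last point of each stroke. Each output row holds the (dx, dy) step to
        # the next point and whether the pen lifts before that next point.
        offsets = np.zeros((len(points) - 1, 3), dtype=np.float32)
        offsets[:, 0:2] = points[1:, 0:2] - points[:-1, 0:2]  # step to the next point
        offsets[:, 2] = points[:-1, 2]                        # pen-up flag for this step
        return offsets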

Model Description

Rather than have our network predict the exact location of the next point and whether the current point is the end of a stroke, we will use the MDN approach and have the network output a set of parameters of a joint probability distribution for the relative location of the next point (\Delta x, \Delta y), along with a simple Bernoulli “coin-flip” distribution for the end-of-stroke probability. The reason the direct prediction approach won’t work is that there are too many different states and contexts for where the next stroke could be, and all we would end up doing is predicting the average expected location of the next point, which is likely to be a trivial outcome, like a noisy line drifting towards the right.

As with the inverted sinusoidal data in the previous post about MDNs, we want to model the different potential states and contexts in the data, and generate a plausible distribution for the next point conditioned on the entire history of previous points, from which we can then sample to generate our handwriting examples. In this MDN, the inputs to the network are the most recent relative stroke movement, the most recent end-of-stroke signal, and the previous hidden state of the network, while the outputs are a set of values that parametrise the probability distribution of the next stroke movement and the next end-of-stroke signal.

Once we have trained the network to generate accurate distributions of the future given the past, we can simply sample from the probability distribution to generate our handwriting sample. It is as if the neural network is dreaming up a handwriting example by feeding its previously generated stroke back into itself. In our demo, we used a 2-layer stacked basic-LSTM network (no peephole connections) with 256 nodes in each layer.

Our model for the probability distribution of the future stroke vector is a joint 2D normal mixture distribution, characterised as a probability-weighted sum of 2D normal distributions, each with its own means and covariances. We used 20 mixtures in our demo, to be consistent with Graves’ paper, although we found that even 5-10 mixtures worked well enough. Since the extra mixtures didn’t noticeably hurt performance or increase the total size of the network (most of the weights are in the LSTM layers), we just kept it at 20. If you want to experiment with a different number of nodes, node types (RNN, GRU, etc.), LSTM peephole connections, number of mixture distributions, or different DropOut probabilities, all of this can be done by setting different flags when running train.py.

In total, our network must output 121 values, Z, for the MDN to infer its distribution. One of these values is used for the end-of-stroke probability, 20 values define the probability of each mixture, and the remaining 100 values make up 20 sets of 2D normal distribution parameters (two means, two standard deviations, and one correlation per mixture). As the output values are real numbers that may not be bounded, we perform a transformation to get values in parameter space:

e = \frac{1}{1+ \exp(Z_{0})}

\Pi_{k} = \frac{\exp(Z_{k})}{\sum_{i=1}^{20} \exp(Z_{i})}

\mu_{1} = Z_{21\rightarrow 40}, \mu_{2} = Z_{41\rightarrow 60}

\sigma_{1} = \exp( Z_{ 61 \rightarrow 80} ), \sigma_{2} = \exp( Z_{ 81\rightarrow 100} )

\rho =\tanh( Z_{ 101 \rightarrow 120} )

Using these transformations, as in the previous MDN example, the \Pi_{k} values undergo the softmax operator so they sum to one, and the end-of-stroke probability e is also bounded between 0 and 1. The standard deviation parameters will be strictly positive, and the correlation between the two coordinates will lie between -1 and 1 after the exponential and hyperbolic tangent transformations have been applied. After the parameters have been obtained, the probability density of the next stroke is defined as:

  P(X=x) = \sum_{k=1}^{K} \Pi_{k} \, \Phi(x_{1}, x_{2} \mid \mu_{1,k}, \mu_{2,k}, \sigma_{1,k}, \sigma_{2,k}, \rho_{k}) \, P(X_{eos}=x_{eos})
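
To make these transformations and the mixture density concrete, here is a rough NumPy sketch that follows the equations above. It is only an illustration, not the TensorFlow code in model.py, and the function names are made up:

    import numpy as np

    def z_to_params(z, K=20):
        # Split a 121-dimensional output vector into MDN parameters,
        # using the same layout and squashing functions as above.
        eos = 1.0 / (1.0 + np.exp(z[0]))                  # end-of-stroke probability in (0, 1)
        pi = np.exp(z[1:1+K]) / np.sum(np.exp(z[1:1+K]))  # softmax mixture weights
        mu1, mu2 = z[1+K:1+2*K], z[1+2*K:1+3*K]           # component means
        sigma1 = np.exp(z[1+3*K:1+4*K])                   # strictly positive std devs
        sigma2 = np.exp(z[1+4*K:1+5*K])
        rho = np.tanh(z[1+5*K:1+6*K])                     # correlations in (-1, 1)
        return eos, pi, mu1, mu2, sigma1, sigma2, rho

    def bivariate_normal_pdf(x1, x2, mu1, mu2, s1, s2, rho):
        # Density of a correlated 2D Gaussian evaluated at (x1, x2).
        z1, z2 = (x1 - mu1) / s1, (x2 - mu2) / s2
        z = z1 ** 2 + z2 ** 2 - 2.0 * rho * z1 * z2
        denom = 2.0 * np.pi * s1 * s2 * np.sqrt(1.0 - rho ** 2)
        return np.exp(-z / (2.0 * (1.0 - rho ** 2))) / denom

    def next_point_density(x1, x2, pi, mu1, mu2, sigma1, sigma2, rho):
        # Probability density of the next (dx, dy) offset under the mixture.
        return np.sum(pi * bivariate_normal_pdf(x1, x2, mu1, mu2, sigma1, sigma2, rho))

Taking the negative log of this density (plus a Bernoulli cross-entropy term for the end-of-stroke bit) gives the training loss discussed below.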

Unlike in the previous example, where all the weights were just stored in a global tensor variable, this task is much more involved and has more moving parts, so we would like the model to be nicely packaged in a class to give an easier-to-use object-oriented interface. We also incorporated DropOut on the outputs of the LSTM layers to regularise training so it tends not to overfit, but we didn’t apply dropout to the input layer, as the sequential and path-dependent nature of writing meant it was important not to miss the ends of strokes. We found DropOut to be fairly effective in this task, and TensorFlow makes it relatively easy to ‘drop’ in this feature. TensorFlow’s rnn_cell module makes it fairly easy to implement stacked RNNs with DropOut. For example, below is all it takes to construct the two-layer LSTM used in our network, with DropOut applied to the outputs:

    cell = rnn_cell.BasicLSTMCell(256)                            # one LSTM layer with 256 nodes
    cell = rnn_cell.MultiRNNCell([cell] * 2)                      # stack two of them
    cell = rnn_cell.DropoutWrapper(cell, output_keep_prob = 0.8)  # dropout on the layer outputs

For training, we minimise the cross-entropy loss, i.e. the negative log-likelihood of the training sequences under the generated distributions, as before. Although efficient closed-form expressions for the gradients are available, we rely on TensorFlow to calculate the gradients automatically via its symbolic engine. Gradient clipping at 10.0 is used to avoid gradients blowing up when backpropagating the derivatives back through time.
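
A minimal sketch of that optimisation setup in TensorFlow, assuming a scalar loss tensor has already been built from the mixture density above (the optimiser choice and learning rate here are my own illustrative assumptions, not necessarily what model.py uses):

    import tensorflow as tf

    learning_rate = 0.005   # illustrative value
    # `loss` is assumed to be the negative log-likelihood built from the mixture density
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(loss, tvars), 10.0)  # clip gradients at 10.0
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))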

For the exact implementation details, please refer to model.py.

The Nitty Gritty Details of the Mini-Batch Training

I have written the model to train on the preprocessed IAM Handwriting data, or, if the preprocessed data hasn’t been built yet, the module will build a cPickle database from the raw XML files. The tricky bit about training is that we want to use minibatches, and the sequences in a minibatch have to be the same length to be efficient. I didn’t want to concatenate every single stroke together and train on a chopped-up, same-size set of stroke data, since there would be a lot of unnatural gaps between lines that are duct-taped together, and we would be training on this artificial noise.

Instead, what I ended up doing was to pick a sequence length of 300 points for training, and throw away training sequences with fewer than 300 points (which were not that many anyway, as most of the training data had between 300 and 2000 points). Then, while creating the minibatches, I would sample a random contiguous portion of 300 points from each sample, as sketched below. For example, if a training sample had 400 data points, the slice inserted into the minibatch could be anywhere from [0:300] to [100:400], which may actually help the model generalise even more (like distorting MNIST images to create more data points). In addition, for samples with many more than 300 points, say 1500 points, I would use that sample on average five times more often than a sample with only ~300-400 points, to ensure that longer samples are not undertrained on. The whole training process lasted roughly 30 epochs, and took about half a day on a MacBook Pro without using a GPU.
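
Here is a rough sketch of that batching scheme (illustrative only; the real bookkeeping lives in the DataLoader class in utils.py, and the helper names below are made up):

    import random
    import numpy as np

    def sample_window(stroke_data, seq_length=300):
        # Pick a random contiguous window of seq_length + 1 points from one
        # example (which must have at least seq_length + 1 points), and return
        # (inputs, targets), where targets are the inputs shifted by one step.
        start = random.randint(0, len(stroke_data) - seq_length - 1)
        window = stroke_data[start:start + seq_length + 1]
        return window[:-1], window[1:]

    def next_batch(samples, batch_size=50, seq_length=300):
        # Build one minibatch, sampling longer examples proportionally more often,
        # so a ~1500-point sample is picked about 5x as often as a ~300-point one.
        lengths = np.array([len(s) for s in samples], dtype=np.float32)
        probs = lengths / lengths.sum()
        x_batch, y_batch = [], []
        for _ in range(batch_size):
            s = samples[np.random.choice(len(samples), p=probs)]
            x, y = sample_window(s, seq_length)
            x_batch.append(x)
            y_batch.append(y)
        return np.array(x_batch), np.array(y_batch)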

Generating Samples from the Network

So after training, our network can generate samples and save them as .svg files. I figured out how to display them in IPython as well, and wrote some modules to automatically display some samples.

When we sample a handwriting sequence, we start off by zeroing out the state of the LSTM network and passing an initial input into the network.

    prev_x = np.zeros((1, 1, 3), dtype=np.float32)
    prev_x[0, 0, 2] = 1 # initially, we want to see beginning of new stroke
    prev_state = sess.run(self.cell.zero_state(1, tf.float32))

The initial input is just a zero vector, but with the end-of-stroke signal turned on, signalling to the network that the next point it produces should be the start of a new stroke rather than the continuation of an existing one; I figured this would give more diverse starting points.

After the initial input and zero state are passed into the network, we get a set of parameters from the output of the network. These are the parameters of a mixture of 2D Gaussians that defines the probability distribution of where the next point will be located, plus one more parameter that defines the probability that the next point will be the start of yet another stroke.

We randomly sample a set of values from this distribution, add the sampled point to the .svg file we are building, and record the state of the network. We then repeat the loop, feeding the sampled point and the network state back in as inputs to get another probability distribution to sample from for the next point, and so on until we have 800 points, or however many the user specifies. Below is the Python pseudocode for the sampling process.

    strokes = np.zeros((num, 3), dtype=np.float32)

    for i in xrange(num):

        # get the model parameters from the network
        feed = {model.input_data: prev_x, model.initial_state:prev_state}
        [model_params, next_state] = sess.run([model.model_params, model.final_state],feed)

        # sample whether we want to end the stroke
        eos = sample_eos(model_params.eos)

        # sample which mixture to use
        idx = sample_mixture_index(random.random(), model_params.pi)

        # sample location of the next stroke
        next_x1, next_x2 = sample_gaussian_2d(model_params, idx)

        # put the current generated stroke as the next input (reshaped to (1, 1, 3))
        prev_x = np.zeros((1, 1, 3), dtype=np.float32)
        prev_x[0, 0, :] = [next_x1, next_x2, eos]

        # record the generated stroke, since we want to draw it later
        strokes[i, :] = [next_x1, next_x2, eos]

        # save the current RNN's state to feed it back in next time
        prev_state = next_state
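
The helper functions in the pseudocode above are stand-ins; one way they could be implemented (a sketch assuming attribute names on model_params that are mine, not necessarily those in sample.py) is:

    import random
    import numpy as np

    def sample_eos(eos_prob):
        # flip the Bernoulli end-of-stroke coin
        return 1 if random.random() < eos_prob else 0

    def sample_mixture_index(x, pi):
        # pick a mixture component by inverting the cumulative weights
        cumulative = 0.0
        for i, p in enumerate(pi):
            cumulative += p
            if x <= cumulative:
                return i
        return len(pi) - 1  # guard against floating-point round-off

    def sample_gaussian_2d(params, idx):
        # draw one (dx, dy) sample from mixture component idx
        mu1, mu2 = params.mu1[idx], params.mu2[idx]
        s1, s2, rho = params.sigma1[idx], params.sigma2[idx], params.rho[idx]
        cov = [[s1 * s1, rho * s1 * s2],
               [rho * s1 * s2, s2 * s2]]
        return np.random.multivariate_normal([mu1, mu2], cov)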

Sample Results

Here is how we use the code interactively in IPython to generate and plot a few examples of, say, 800 points:

    %run -i sample.py
    [strokes, params] = model.sample(sess, 800)
    draw_strokes_random_color(strokes)

generated_examples_0

generated_examples_1

These results don’t look that bad! The pen seems to switch sometimes between printing and cursive writing, as if it were written by a crazy lunatic trapped in a room high up in a deserted castle, with multiple delusional personalities. We all know one or two people like that …

In addition to saving the sampled points, we also saved the history of the probability distribution parameters for further visualisation of what is really going on. To get an idea, in the sample below we plot the generated sample along with two more plots.

    draw_strokes_pdf(strokes, params)
    draw_strokes_eos_weighted(strokes, params)

full_set2

The second and third plots are visualisations of the probability distribution of the writing process as the network dreams up the sample. In the second plot, we plot the actual sampled path, plus, for every point, the generated probability density of the next point. In the third plot, we overlay the sampled path with the end-of-stroke probability of each point. We can see that as we get near the end of each stroke, the probability of the end-of-stroke signal naturally increases: the line becomes darker. In addition, we see in the second plot that when the network is writing a continuous stroke, it is fairly confident about the location of the next point, as evidenced by the small red dots implying a dense distribution in a small target area. Meanwhile, as it gets near the end of a stroke, the probability density for the next point becomes more spread out, and sometimes the network will generate a larger variety of possible next spots, as evidenced by the larger, more transparent blobs on the plot.

Here is the same visualisation again for another two sets of samples with longer sequence lengths:

full_set1

full_set0

To Do Next …

The next thing to try would be implementing the handwriting synthesis part of Graves’ paper, which may incorporate bits of the character embedding used in char-rnn. He found that combining character prediction with stroke prediction was key to generating synthetic handwriting sequences that look natural, as the network needs to learn how the strokes of one character flow into the next character.

Another interesting direction would be to incorporate Generative Adversarial Network approaches into recurrent networks: train one network to discriminate between fake handwriting and real handwriting, and another network to generate fake handwriting to fool the discriminator. This can be quite tough to do with RNNs, although we must try anyway! Some think that generative moment matching networks may be a better approach than GANs for generative RNNs.