
Fake Kanji characters generated from an LSTM Mixture Density Network in SVG format using sketch-rnn.
Code available on GitHub.

Update June 2018: An interactive browser demo is now available online: otoro.net/kanji-rnn. This demo uses the more recent Sketch-RNN model (JavaScript, TensorFlow), trained on a more refined dataset.

Note: Kanji (漢字) is the Japanese term for Chinese characters. I use the two terms interchangeably depending on the context.

Introduction

This is the third post in a series of blog posts logging my experiments with TensorFlow. For a better understanding, feel free to read the first post on Mixture Density Networks (MDNs), and the second post on Long Short-Term Memory networks combined with MDNs to generate fake handwriting examples, as in Alex Graves’ legendary sequence-generation paper.

We will be modifying and extending Graves’ approach to get an LSTM + MDN to generate fake Chinese characters in vector format. The main differences in approach are modelling an additional end-of-content state alongside the existing end-of-stroke state, and applying selective gradient boosting to these less likely pen states.

Motivation

As a child growing up in a mostly English-speaking country, I was forced by my parents to attend these dreadful Saturday morning classes where I was to be taught Chinese. There would be these dictation tests where the students had to write out full passages of memorised Chinese text from a textbook, usually indirectly exposing us to Confucian moral values. We would have to spend a lot of time during the weeknights memorising passages to prepare for the test on the following Saturday. A score less than perfect was frowned upon. This went on for years. I still have nightmares about those dictation tests. I think that’s how most children around the world learn Chinese as well, via this rote learning method. Maybe in some sense, Chinese language education resembles how LSTMs are trained to reproduce sequences from training examples.


Example of a child’s struggles at learning Chinese characters
(source: random link found on the web)


From personal experience, being able to write Chinese characters is very different from being able to read them. Kanji writing proficiency has been on the decline for decades, starting in Japan and now probably spreading to the rest of Asia, as most people type Asian text via a phonetics-based input method (kana input for Japanese and pinyin for Chinese) and then visually choose the desired character from a bunch of likely candidates on the screen. The problem arises when people sit down to write a New Year’s card to be mailed by post to their old Sensei and suddenly realise they have forgotten how to write Kanji. I am also guilty of this: even though I read a lot of Chinese and Japanese content in my everyday life, I struggle to write Chinese characters. What we notice is that while we can definitely read and recognise the characters we are able to write, the converse is certainly not true.

Similarly, in machine learning, many problems start off as classification tasks. What is this digit? Is this picture a dog, or a house? Is this transaction a fraud? Should we lend money to this dude applying for a mortgage? Man or woman? How old is she?

While these tasks are very useful, I think a more interesting task is to generate data, which I view as an extension of classifying data. Just as being able to write a Chinese character demonstrates more understanding than merely knowing how to read that character, I think being able to generate content is also key to understanding that content. Being able to generate a picture of a 22-year-old attractive lady is much more impressive than merely being able to estimate that this woman is likely around 22 years of age.

An example of a generative task is machine translation, where English is translated into another language in real time. Generative art and music have become increasingly popular. Recently, there has been work on using techniques such as generative adversarial networks (GANs) to generate bitmap pictures of fake images that look like real ones, like fake cats, fake faces, fake bedrooms and even fake anime characters. To me, those problems are a lot more exciting to work on, and a natural extension to classification problems.

What interests me more, though, is the ability to generate vectorised content. I think a lot of useful content is best represented in vector format, as opposed to rasterised bitmap images. For example, digitised pen sketches, CAD designs, geo-location tag data like where we bike on Strava, scientific experimental data, and of course handwriting are all better expressed as vectorised data.


Examples from KanjiVG Stroke Order Database.


I think fonts and handwriting are best expressed with vectors, rather than bitmaps. As someone interested in design, I’m a bit of a font geek, and I really appreciate well-designed, beautiful TrueType fonts, which look stunning regardless of the size they are displayed at. I think pen strokes recorded from digitised writing are much more naturally represented as vector data, and I’m a big fan of online handwriting data compared to offline data, whenever it is made available.

In this blog post, I will describe how to train a recurrent neural network to generate fake, but plausible, Chinese characters in vector .svg format. For the training, we will expose the network to real examples of Chinese characters from a stroke-order database of Kanji, so the network will also need to write its made-up Chinese characters with a plausible stroke order.


Game that teaches you to write perfect Kanji.


Stroke order is very important in Japanese culture, a society where the process matters just as much as the end result. Some calligraphers take stroke order very seriously, and will probably explode if they see someone writing a Kanji with incorrect stroke order. Japanese companies have even created video games to teach the correct stroke order of Kanji, like the one pictured above. Only vectorised data can really model this stroke order, and capture the pure essence of Kanji, and probably some part of Asian culture. Rasterised Kanji data is the equivalent of watching the English dub of a popular Anime, two years after it was released in Tokyo, as a low-quality RealMedia streaming file. As such, we will want our generative recurrent neural net to learn to write Kanji characters obeying and respecting the proper stroke order, to maintain order in the universe.

Background

Our generative prediction model will follow the same framework outlined in Graves’ paper, where he demonstrated both generative text and generative handwriting. Karpathy’s post and char-rnn implementation have some fantastic examples of how this framework is used to generate data that is represented as text.


Generative Sequence Model Framework


In the text-generation example, assuming we have a model that has already been pre-trained, we feed an initial random character into the model, along with an initially empty state. The model uses the state’s information, together with the current input, to generate a probability distribution for the next character. That distribution is then sampled randomly (possibly distorting the sampling process by applying temperature) to obtain a prediction for the next character. The sampled character is then fed back in as the next input, along with the model’s updated internal state.
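To make the loop concrete, here is a minimal sketch in Python of this sample-and-feed-back procedure. The model.step method and its (probabilities, state) return signature are hypothetical placeholders for illustration, not part of the sketch-rnn code or any real library.

import numpy as np

def generate_sequence(model, start_token, max_steps=250, temperature=1.0):
    # `model.step(token, state)` is assumed to return (probs, new_state),
    # where `probs` is a discrete distribution over the next token.
    token, state = start_token, None
    sequence = [token]
    for _ in range(max_steps):
        probs, state = model.step(token, state)
        # apply temperature: < 1 sharpens the distribution, > 1 flattens it
        logits = np.log(probs + 1e-8) / temperature
        probs = np.exp(logits - np.max(logits))
        probs /= probs.sum()
        token = np.random.choice(len(probs), p=probs)
        sequence.append(token)
    return sequence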

A simple model that fits this framework is the basic N-gram character model. In an N-gram model, all we do is keep a record of the frequencies of characters following each run of the previous N-1 characters, and use that historical table of frequencies as the probability distribution from which we draw the next character.
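For example, a toy character-level N-gram sampler fitting this framework could look like the following (my own illustration, not code from the repo):

import random
from collections import Counter, defaultdict

def train_ngram(text, n=3):
    # record how often each character follows each (n-1)-character context
    counts = defaultdict(Counter)
    for i in range(len(text) - n + 1):
        counts[text[i:i + n - 1]][text[i + n - 1]] += 1
    return counts

def sample_ngram(counts, seed, length=100, n=3):
    # generate new text by repeatedly sampling from the frequency table
    out = seed
    for _ in range(length):
        options = counts.get(out[-(n - 1):])
        if not options:
            break
        chars, freqs = zip(*options.items())
        out += random.choices(chars, weights=freqs)[0]
    return out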

This framework can also be represented by a recurrent neural network, where the states are the hidden states of recurrent LSTM nodes, and the output values of the network can be converted into a discrete probability distribution by applying a softmax layer to the outputs. To train the weights of the neural network, we need a way to compare the predicted distribution to the actual distribution of the training data. What is usually done is that a cross-entropy loss function is applied, comparing the model’s predicted probabilities after the softmax layer with the actual data of the entire generated sequence.

This has been done in Graves’ sequence generation paper and implemented as char-rnn by Karpathy. char-rnn has been used successfully to generate not only Shakespeare’s text, but also bizarre examples such as Linux source code, LaTeX documents, Wikipedia-formatted XML articles, and music scores.

sketch-rnn, the char-rnn for svg training data

I wanted to create a char-rnn-like tool, but for learning to draw sketches rather than learning to generate character sequences. SVG data is readily available on the web, although obviously not as easy to obtain as text data. In the end, I created a tool called sketch-rnn that attempts to learn some structure from a large collection of related .svg files, and can generate and dream up new vectorised drawings that are similar to the training set. Just as char-rnn can take Donald Trump quotes and generate hypothetical Donald Trump wisdom, I wanted to be able to feed in a large collection of .svg pictures of cats and have an algorithm come up with new vectorised pictures of cats. It was difficult to obtain .svg pictures of cats in sufficient quantities, while it was quite easy to obtain .svg files for Chinese characters, so in the end this turned into an experiment to generate fake Kanji.

sketch-rnn follows the same principle as the previous handwriting generation demo blog post. Each drawing is modelled as pen-stroke-like data, where each step contains an offset in the x and y axes, as well as whether the pen is on the paper or off the paper, and a small line is drawn from the previous step to the current step if the pen is down. The neural network has to come up with a probability distribution for the next step. Unlike the character-generation example, where the pdf is just a pile of discrete probabilities for each possible character, we need a continuous distribution for the offsets in the x and y directions, along with a probability that the pen will be lifted above the paper in the next step (this is called the end-of-stroke probability). The distribution used is a mixture Gaussian distribution that estimates the pdf of the next offsets to the x and y pen locations. This is the Mixture Density Network method used to generate fake handwriting in the previous blog post.
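A minimal illustration of this stroke format (not the exact internal representation used by sketch-rnn): each step is a row of (dx, dy, pen-state), and absolute positions are recovered with a cumulative sum.

import numpy as np

# Each row is one step: [dx, dy, eos], where eos = 1 means the pen is
# lifted off the paper after reaching this point (end of the current stroke).
strokes = np.array([
    [ 6.0,  0.0, 0],   # drawing a horizontal line
    [ 6.0,  0.0, 1],   # end of the horizontal stroke, pen lifts
    [-9.0,  4.0, 0],   # pen was lifted last step, so this offset just moves the pen
    [ 0.0,  7.0, 1],   # a vertical stroke is drawn, then the pen lifts again
], dtype=np.float32)

def to_absolute(strokes):
    # convert relative offsets back into absolute pen positions for rendering
    xy = np.cumsum(strokes[:, :2], axis=0)
    return np.hstack([xy, strokes[:, 2:]])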


Example of Mixture 2D Gaussian Densities


Above is an example of the mixture Gaussian densities we use to draw strokes of Chinese characters. As the strokes are drawn, shown via the black dots with connected lines, the LSTM+MDN algorithm continually estimates the probability distribution of where the next dot will be. The distribution is modelled as a mixture Gaussian distribution. What this means is that the next location is composed of a mixture of many different candidate locations (shown as the different shaded red ellipses), and each candidate location is itself a 2-dimensional joint Gaussian distribution over the x and y offsets, each with its own 2×2 covariance matrix.
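As a rough sketch, sampling the next (dx, dy) offset from such a mixture of bivariate Gaussians could be done like this; the parameter names mirror the symbols in the equations later in the post, but the code is only illustrative.

import numpy as np

def sample_offset(pi, mu_x, mu_y, sigma_x, sigma_y, rho):
    # pi, mu_*, sigma_*, rho are length-K arrays; pi must sum to 1
    k = np.random.choice(len(pi), p=pi)      # pick one mixture component
    mean = [mu_x[k], mu_y[k]]
    cov = [[sigma_x[k] ** 2,                  rho[k] * sigma_x[k] * sigma_y[k]],
           [rho[k] * sigma_x[k] * sigma_y[k], sigma_y[k] ** 2]]
    return np.random.multivariate_normal(mean, cov)   # the sampled (dx, dy)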

In addition to the stroke-location distribution and the end-of-stroke probability, we also need to model the probability that we have finished writing this Chinese character, or that we are finished drawing this sketch entirely (we will call this the end-of-char, or end-of-content, ‘eoc’ probability). Originally, I just added another random variable alongside the end-of-stroke probability and modelled it separately, but I found that this didn’t work well in practice. We know that whenever we have an end-of-char event, that event will also be an end-of-stroke event, so there is some mutual information that we would be discounting.

It took me a while to figure out a better way to model both signals, and in the end I modelled the state of the pen as a set of discrete states, which can be represented in the neural network as a softmax layer. The pen state can be one of (end-of-stroke, end-of-char, or pen-is-down), and at each step the model has to assign probabilities to these three states. This seems like a very elegant approach, and is similar to using char-rnn to model the three pen states, in addition to using a mixture distribution to model the x and y offsets.
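A toy sketch of this three-way pen state as a softmax output, illustrative only (the actual layer lives inside the TensorFlow graph):

import numpy as np

PEN_DOWN, END_OF_STROKE, END_OF_CHAR = 0, 1, 2

def sample_pen_state(logits):
    # numerically stable softmax over the three pen-state logits
    z = np.exp(logits - np.max(logits))
    probs = z / z.sum()
    return np.random.choice(3, p=probs), probs

# example: the model is fairly confident the pen stays on the paper
state, probs = sample_pen_state(np.array([2.0, -1.0, -4.0]))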

Processing Training Data

I used the KanjiVG database, put together by the really nice people who built a fantastic open source Japanese language learning tool. As mentioned earlier, stroke order matters when writing a Kanji: even if one writes a Kanji in an incorrect stroke order and arrives at a final character that resembles the correct character, it is still an incorrect Kanji. The KanjiVG GitHub repo contains around 11,000 Kanji .svg files, and the path elements of each .svg file are ordered to respect the official Japanese Kanji stroke order. As an aside, there are minor differences between some Japanese Kanji stroke orders and Chinese character stroke orders, and this probably adds to existing geopolitical tensions near the South China Sea.

In sketch-rnn, the SketchLoader class will read in all the .svg files located inside the data subdirectory, and then slice and dice all the line and path elements into smaller lines that are more suitable for training. Please refer to the code for the specifics of the granularity and path-to-lines conversion, as it can all be customised.
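As a rough idea of what this conversion involves, the sketch below samples points along each path element using the svg.path library and turns them into (dx, dy, end-of-stroke) steps. It is only a simplified approximation of what SketchLoader does; the real code handles granularity, line elements and other details differently.

import xml.etree.ElementTree as ET
from svg.path import parse_path   # pip install svg.path

SVG_NS = '{http://www.w3.org/2000/svg}'

def svg_to_strokes(filename, points_per_path=10):
    # approximate each <path> by a short polyline of evenly sampled points
    root = ET.parse(filename).getroot()
    strokes = []
    for elem in root.iter(SVG_NS + 'path'):
        path = parse_path(elem.get('d'))
        pts = [path.point(i / (points_per_path - 1.0))
               for i in range(points_per_path)]
        strokes.append([(p.real, p.imag) for p in pts])
    return strokes

def strokes_to_steps(strokes):
    # flatten the polylines into (dx, dy, end-of-stroke) steps
    steps, prev = [], (0.0, 0.0)
    for stroke in strokes:
        for i, (x, y) in enumerate(stroke):
            eos = 1 if i == len(stroke) - 1 else 0
            steps.append((x - prev[0], y - prev[1], eos))
            prev = (x, y)
    return steps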

Below are some training examples of training characters extracted from KanjiVG:


Training examples from the KanjiVG dataset.
Different colours show different strokes of each character.

The SketchLoader object will then dump all the lines extracted from the .svg files, as an array of stroke arrays, into a cPickle binary file for future training use. It will also generate the minibatch sets used later in the training process.
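A simplified sketch of this step, assuming each character has already been converted into an array of (dx, dy, pen-state) rows; the real SketchLoader has more options for batching and padding.

import pickle
import random
import numpy as np

def save_strokes(all_chars, filename='strokes.cpkl'):
    # all_chars: list of per-character arrays, each of shape (T, 3)
    with open(filename, 'wb') as f:
        pickle.dump(all_chars, f)

def random_minibatch(all_chars, batch_size=50, seq_length=300):
    # pick random characters and pad or truncate each to a fixed length
    batch = np.zeros((batch_size, seq_length, 3), dtype=np.float32)
    for i, char in enumerate(random.sample(all_chars, batch_size)):
        n = min(len(char), seq_length)
        batch[i, :n] = char[:n]
    return batch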

Boosting Gradients for Less Likely Events

The initial results were a bit disappointing, because the pen would just ramble on and never move to the next character, despite generating realistic shapes for individual strokes. What I noticed is that the algorithm generally underestimated, or simply ignored, the end-of-character probability while minimising the log-likelihood loss function. I think this is because end-of-char events don’t happen often, occurring on the order of once every 100 steps.

This problem did not occur in Graves’ handwriting generation work. With the English handwriting training data, we didn’t have to model an end-of-content signal, and could let the computer ramble on until the end of time. In this problem, we actually need to train the algorithm to know exactly when to stop writing, after it thinks it has finished a complete Chinese character.

Example of an initial training result where the computer rambles on, if the boosting method outlined in this section hasn’t been applied:


Example of rambling. Computer doesn't know when to stop.


To get around this problem, I devised a way to increase the error, and hence the gradients, for the data points where the end-of-char signal occurs, and modified the loss function accordingly to include these increased weightings.

Generative Model:


P(X=x, Y=y, M=m) = P(X=x, Y=y) \, P(M=m)

P(X=x, Y=y) = \sum_{k=0}^{K-1} \Pi_{k} \, \Phi(x, y, \mu_{x,k}, \mu_{y,k}, \sigma_{x,k}, \sigma_{y,k}, \rho_{k})

P(M=m) = z_{m} = \begin{cases} p_{stay} & \text{pen stays on the paper} \\ p_{eos} & \text{pen is to be lifted up} \\ p_{eoc} & \text{stop drawing after this point} \end{cases}

where x, y, and m denote the random values of the x-offset, y-offset, and pen state for the next step of the sketch. The pair (x, y) is modelled with a mixture of 2D Gaussian distributions from the MDN output, while m is modelled as a softmax one-hot output z_{m}.

Gradient-boosted Loss Function for each step:


Loss(x, y, m) = - \log \left( \sum_{k=0}^{K-1} \Pi_{k} \, \Phi(x, y, \mu_{x,k}, \mu_{y,k}, \sigma_{x,k}, \sigma_{y,k}, \rho_{k}) \right) - w(m) \log \left( z_{m} \right)

w(m) = \begin{cases} 1 & \text{pen stays on the paper} \\ 10 & \text{pen is to be lifted up} \\ 100 & \text{stop drawing after this point} \end{cases}

w(m) attempts to boost the loss function for less likely events in the data.

I chose a factor of 10 for end-of-stroke points and 100 for end-of-character points, alongside the default unit factor for pen-down points, so that the weighting somewhat compensates for the lower probability of those signal events. I found this method to be very effective, and when combined with the example-diversity and shuffling tricks I will describe in the next section, the end results improved substantially.
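Here is a sketch of this weighted loss written against the modern tf.keras / TF2 API (the original repo used the 2015-era TensorFlow API, so this is not the exact implementation); the tensor names follow the equations above.

import tensorflow as tf

def boosted_loss(mixture_likelihood, pen_state_probs, pen_state_onehot):
    # mixture_likelihood: [batch, steps] value of the mixture density
    #     evaluated at the true (x, y) offsets (the sum over K above)
    # pen_state_probs:    [batch, steps, 3] softmax output (stay, eos, eoc)
    # pen_state_onehot:   [batch, steps, 3] one-hot ground-truth pen states
    weights = tf.constant([1.0, 10.0, 100.0])                # w(m)
    stroke_loss = -tf.math.log(mixture_likelihood + 1e-8)
    z_m = tf.reduce_sum(pen_state_probs * pen_state_onehot, axis=-1)
    w_m = tf.reduce_sum(weights * pen_state_onehot, axis=-1)
    pen_loss = -w_m * tf.math.log(z_m + 1e-8)
    return tf.reduce_mean(stroke_loss + pen_loss)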

I’m pretty sure someone has thought of this weighted boosting method for improbable events before, as it is quite an obvious approach. I looked it up on the web and couldn’t find anything though (the closest was this). If you can’t find it either, well, you heard it here first.

Diversify Examples – Shuffled Minibatch Creation, Randomised Scaling

At the beginning of each epoch, we shuffle the order of the characters in the training data. The stroke-order database was created in such a way that similar groups of characters sit next to each other, and I wanted each batch to comprise a more representative set of training examples from the Chinese language. I want to increase the diversity within minibatches; otherwise the algorithm will just spend its time learning only fish-Kanji for a very long time, have the gradients distort its perception of the universe so that it thinks everything is a fish, and then suddenly relearn bird-Kanji for a while and get its mind all messed up. It’ll probably get messed up and confused regardless when it tries to learn the Ultimate Character.

One limitation is that the stroke-order dataset only has around 10,000 examples of Kanji writing, and I felt we needed to manufacture more artificial data for the training. So I took a trick I learned from distorting MNIST training examples to create more training data. Each minibatch is randomly scaled to anywhere within +/- 30% of its original size. The structure of the data, where everything is just an offset in space, makes this scaling very easy to do, by just multiplying the entire matrix by one factor, as sketched below. I do this distortion on the fly rather than precomputing extra examples and storing them, and just increase the epoch count. What I could also do is distort the x and y axes by different scaling factors, which I haven’t done yet at the time of writing, but it’s really easy to modify the tool to do that with just an extra line of code.
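A sketch of this on-the-fly augmentation, assuming a minibatch array of shape (batch_size, seq_length, 3):

import numpy as np

def random_scale(batch, amount=0.3):
    # scale all (dx, dy) offsets by one random factor in [1-amount, 1+amount];
    # the pen-state column is left untouched
    scaled = batch.copy()
    scaled[:, :, 0:2] *= 1.0 + np.random.uniform(-amount, amount)
    # per-axis variant mentioned in the text (one extra line per axis):
    # scaled[:, :, 0] *= 1.0 + np.random.uniform(-amount, amount)
    # scaled[:, :, 1] *= 1.0 + np.random.uniform(-amount, amount)
    return scaled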

Finally, each minibatch starts at the beginning of a character, not in the middle, as I want the algorithm to be exposed to full structures. I don’t want the algo to start training in the middle of a basic pattern. So the next batch will just skip to the start of the next character sample. As a consequence, each epoch may not have the exact same number of minibatches due to this skipping, but it should be close.

Fine Tuning

The beginning of the experiment was quite frustrating, as I only got gibberish, and the algorithm rambled on and on, writing a giant web of lines like a crazy madman. It took me a while to come up with the idea of using softmax to model both end-of-stroke and end-of-character, and also a lot of caffeine to come up with the gradient-boosting idea.

The tricks we used (boosting gradients for less likely events, and diversifying training examples) improved the quality of the results. That being said, the final output still occasionally contains bad samples. As a final filter, I threw out results that exceeded the size of the writing area and had the model start over until it got everything within a specified box area.

Model Setup

I included a smaller pre-trained net in the sketch-rnn GitHub repo, so if you like, you can try to run sketch-rnn on your machine by just running python sample.py.

The smaller pre-trained net generates 24 Gaussian mixture distributions for each time step, and uses 2 layers of 256 LSTM nodes, with a dropout keep probability of 80% applied at the outputs of each layer.
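For a sense of scale, here is what a network of roughly this size looks like in the modern tf.keras API; the actual repo uses the 2015-era TensorFlow API, and its exact input encoding may differ from the 3-value steps assumed here.

import tensorflow as tf

NUM_MIXTURES = 24
# 6 parameters per mixture (pi, mu_x, mu_y, sigma_x, sigma_y, rho) + 3 pen-state logits
OUTPUT_SIZE = 6 * NUM_MIXTURES + 3

model = tf.keras.Sequential([
    tf.keras.layers.LSTM(256, return_sequences=True, input_shape=(None, 3)),
    tf.keras.layers.Dropout(0.2),           # keep probability of 80%
    tf.keras.layers.LSTM(256, return_sequences=True),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(OUTPUT_SIZE),     # raw MDN + pen-state outputs
])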

I scaled the data down in size by a factor of 15. This is an interesting problem, as the typical training examples have sizes of around 80 to 160 units on each axis. I found a good rule of thumb is to scale the data down so that the average dimension of the data is on the order of 10×10; typically for Chinese characters, the offset of each successive step is then on the order of 1×1 in size.

Using minibatches of 50-100 examples seemed to work well. I tried to have a relatively large initial learning rate, and have that learning rate decrease proportionally by 1% after each subsequent epoch. Sometimes having too large a learning rate will crash the training, and the part of the training that crashes is the estimation of the end-of-character likelihood. It is a bit tricky when we need to estimate the probability of unlikely events using the gradient-boosting method above, and that may lead to numerical instability.
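The decay schedule amounts to a 1% multiplicative decrease per epoch; the initial value of 0.005 below is just a placeholder, not the actual setting used.

def learning_rate(epoch, initial_lr=0.005, decay=0.99):
    # decrease the learning rate proportionally by 1% after each epoch
    return initial_lr * (decay ** epoch)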

Example Results


Fake Kanji generated via sample.py in sketch-rnn.

I’m quite happy with the results. sketch-rnn was able to generate a variety of Kanji that do not exist, but somewhat resemble the way Kanji are supposed to be written. Many radicals and basic parts of Kanji are placed and configured in locations that make sense in terms of forming the structure of a Kanji. It seems to resemble a child struggling to pass a Chinese dictation test and trying to wing it by desperately making up answers.

Some Interesting Examples


six stoned ladies
lonely ghost
wooden food
urban sheep
wood pecker
stop eating lambs
bird hunting
educated horse
listening bird
listerine
new type of wooden house
lucky horse

Below are some other notable examples I couldn’t really describe. Could you?

Some examples remind me of some Cantonese profanity converted into new Chinese characters (like 𨳊 or 撚).

Future Work

Cursive Chinese Writing

I have also looked at this online handwriting database by CASIA. It would be really easy to apply this algorithm to that data and possibly train the recurrent net to generate fake cursive Chinese handwriting. Personally, I don’t find that as interesting as this stroke-based dataset, because I wanted to see whether the algorithm could generate distinct structures inside Chinese characters, rather than squiggly handwritten characters like those already produced in the previous handwriting example.


廣 -> 広 -> 广

As an additional rant: as a designer, I’m not a big fan of post-1956 Simplified Chinese, as I feel that the PRC has done too much to simplify Chinese, in an Orwellian Newspeak sense of the language. Compare the original, beautiful traditional Chinese character, its accepted simplified form used for handwriting across Asia pre-1900, and the post-1956 Simplified Chinese form in the example above.

Sketches of everyday things

It would also be interesting to expose sketch-rnn to the TU Berlin sketch data from (Mathias Eitz, James Hays and Marc Alexa, 2012) and see what it draws after learning from the data. I have a feeling it won’t work that well though, since the sketches are quite diverse and may not contain similar statistical structures. If we had example sketches of 10,000 houses, all with a similar number of strokes, complexity, and dimensions, it might work well. But if we have a database of tables, rabbits, fish, apples, and buildings, I think it will be too much for this model to handle. One interesting thing could be to take existing algorithms that convert rasterised images into vectorised .svg format and run sketch-rnn on the output.


Elephant from TU Berlin Sketch Database


Extension to method

In the future I would also like to work on more powerful approaches to understanding patterns beyond LSTM + MDN, which is just an extension of LSTM + Softmax. Recent work on Variational Autoencoders, Generative Moment Matching Networks, or the much-hyped BPL could be a lot more expressive and powerful. I’m also thinking GAN approaches may work on recurrent nets, although my feeling is that LSTM GANs will be very difficult to train.

I’m also thinking of ways to get more generative power out of smaller nets. Eventually I want to be able to use these trained nets inside the web browser and have client-side JavaScript run generative demos that can interact with a user’s sketching activity in real time, which I think would be super cool. If anyone has any ideas on how to compress LSTM nets effectively into small JSON files, let’s discuss.

Updates:
Discussion about this work at University of Pennsylvania's Linguistic Group.
This article has also been translated into Simplified Chinese by @weakish.

Citation

If you find this work useful, please cite it as:

@article{ha2015chinese,
  title   = "Recurrent Net Dreams Up Fake Chinese Characters in Vector Format with TensorFlow",
  author  = "Ha, David",
  journal = "blog.otoro.net",
  year    = "2015",
  url     = "https://blog.otoro.net/2015/12/28/recurrent-net-dreams-up-fake-chinese-characters-in-vector-format-with-tensorflow/"
}