This is a TensorFlow implementation of a Neural Turing Machine. The code is inspired by the DNC implementation by Mostafa-Samir and therefore follows the same structure and organization. However, some of the content-addressing functions were adapted from carpedm20's NTM implementation.
The code is designed to handle variable-length inputs more efficiently than the reference code provided by carpedm20.
The implementation currently supports only the copy task described in the paper.
[Important Notes]:

- For training efficiency, I replaced the circular convolution code proposed by carpedm20 with already-optimized TensorFlow functions. This code, however, was hard-coded for `batch_size = 1` and `shift_range = 1`. If you want to use other values, remember to either modify the proposed function (`apply_conv_shift` in `memory.py`) or replace it with the commented code. `batch_size` can now be different from 1 for training, but `shift_range` is still hard-coded to 1 (see the first sketch after this list).
- Similarly, the Hamming distance computation used in `train.py` for performance visualization is currently hard-coded for `batch_size = 1`. Modifying this, however, should be much simpler (see the second sketch below).
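For reference, here is one way a circular convolutional shift with `shift_range = 1` could be written for arbitrary batch sizes using stock TensorFlow ops. This is a minimal sketch based on `tf.roll`, not the code actually used in `memory.py`, and the assumed column ordering of the shift weights `s` (shifts -1, 0, +1) may differ from the repository's convention:

```python
import tensorflow as tf

def circular_conv_shift(w, s):
    """Sketch of the NTM circular shift for shift_range = 1.

    w: attention weights, shape [batch_size, memory_locations]
    s: shift weights, shape [batch_size, 3], assumed column order (-1, 0, +1)
    """
    shifts = [-1, 0, 1]  # allowed integer shifts for shift_range = 1
    # Roll the weights by each allowed shift: rolled[b, i, k] = w[b, i - shifts[k]]
    rolled = tf.stack([tf.roll(w, shift=k, axis=1) for k in shifts], axis=-1)
    # Weight each rolled copy by its shift probability and sum them
    return tf.reduce_sum(rolled * s[:, tf.newaxis, :], axis=-1)
```

Generalizing to a larger `shift_range` would only require extending the `shifts` list to `range(-shift_range, shift_range + 1)` and widening `s` accordingly.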
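Likewise, a batched Hamming distance needs only a couple of reductions. The function name and tensor shapes below are assumptions for illustration, not the actual code in `train.py`:

```python
import tensorflow as tf

def batched_hamming_distance(y_true, y_pred):
    """Count mismatched bits per batch element.

    y_true, y_pred: shape [batch_size, sequence_length, bits],
    with y_pred holding sigmoid outputs in [0, 1].
    """
    bits_pred = tf.cast(y_pred > 0.5, tf.int32)  # binarize predictions
    bits_true = tf.cast(y_true > 0.5, tf.int32)
    # Sum mismatches over the time and bit dimensions -> shape [batch_size]
    return tf.reduce_sum(tf.abs(bits_pred - bits_true), axis=[1, 2])
```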
The model was trained and tested on a machine with:
- Intel(R) Core(TM) i7-4790K CPU @ 4.00GHz
- 16GB RAM
- GeForce 1080Ti (performance on a network as small as this one is actually better on my CPU)
- Linux Mint 20
- TensorFlow 2.4.0
- Python 3.8.5
To train the model on the copy task:
python train.py --iterations=100000
You can reproduce similar results or play with the sequence length in the visualization notebook.
It is really interesting to check the memory location map, where you can see that the network learns to address memory sequentially, reading locations in the same order in which they were written.
The model was trained with a maximum sequence length of 10, but it can generalize to longer sequences (in this case, a sequence length of 30).
I trained the model with a recurrent controller, a maximum sequence length of 10, and an input size of 8 bits plus 2 flags (start and finish) for 100,000 iterations. A sketch of this input layout and the learning process can be seen below:
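To make the input layout concrete, here is a hedged sketch of how such a copy-task example could be generated. The function name, flag channel positions, and target layout are assumptions for illustration rather than the repository's actual data pipeline:

```python
import numpy as np

def make_copy_example(seq_len=10, bits=8, batch_size=1):
    """Build one copy-task batch: emit a random bit sequence framed by
    start/finish flags, then expect the model to reproduce it."""
    width = bits + 2                 # 8 payload bits + start flag + finish flag
    total = 2 * seq_len + 2          # input sequence, two flags, answer window
    seq = np.random.randint(0, 2, size=(batch_size, seq_len, bits))
    inputs = np.zeros((batch_size, total, width), dtype=np.float32)
    targets = np.zeros((batch_size, total, bits), dtype=np.float32)
    inputs[:, 0, bits] = 1.0                  # start flag (assumed channel)
    inputs[:, 1:seq_len + 1, :bits] = seq     # the payload to copy
    inputs[:, seq_len + 1, bits + 1] = 1.0    # finish flag (assumed channel)
    targets[:, seq_len + 2:, :] = seq         # expected reproduction
    return inputs, targets
```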