AutoCorrect / Spell Check using Attention Model in Python!

Piyush M
Jun 6, 2021

Introduction

This is Part 2 of our autocorrect/spell check blog. Click here to look at the previous blog; I would recommend reading it before this one. Today we are going to implement an attention model to further improve our accuracy. Want to check out the code now? Check out my GitHub link.

Data Preparation

Today we shall use the English-to-Spanish translation dataset, as there is no standardized dataset for autocorrect.

import os
import tensorflow as tf

# Download the file
path_to_zip = tf.keras.utils.get_file(
    'spa-eng.zip',
    origin='http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip',
    extract=True)
path_to_file = os.path.dirname(path_to_zip) + "/spa-eng/spa.txt"

We first preprocess the input sentence to remove all punctuation and keep only English letters.
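A minimal sketch of what this preprocessing could look like (the exact version lives in the notebook; unicode_to_ascii strips accents before the regex keeps only letters and spaces):

import re
import unicodedata

def unicode_to_ascii(s):
    # Strip accents so e.g. "á" becomes "a"
    return ''.join(c for c in unicodedata.normalize('NFD', s)
                   if unicodedata.category(c) != 'Mn')

def preprocess_sentence(sentence):
    # Lowercase, drop everything except a-z and spaces, collapse whitespace
    sentence = unicode_to_ascii(sentence.lower().strip())
    sentence = re.sub(r"[^a-z ]+", " ", sentence)
    sentence = re.sub(r"\s+", " ", sentence).strip()
    return sentence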

en_sentence = u"May I borrow this book?"
print(preprocess_sentence(en_sentence))
# may i borrow this book

Next, we prepare our input data. The process is quite simple: we take the English text from the corpus (ignoring the Spanish text), add some noise to it, and treat that as the input to our model. The output is the unchanged English text from the corpus.
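As a rough illustration, the noise step might randomly delete, replace, or insert characters. The add_noise below is only a stand-in; the exact corruption scheme used in the notebook is the one described in the previous blog.

import random

def add_noise(sentence, noise_level=0.1):
    # Illustrative typo generator: with probability noise_level, delete,
    # replace, or insert a random lowercase letter at each position.
    letters = 'abcdefghijklmnopqrstuvwxyz'
    noisy = []
    for ch in sentence:
        if ch != ' ' and random.random() < noise_level:
            op = random.choice(['delete', 'replace', 'insert'])
            if op == 'replace':
                noisy.append(random.choice(letters))
            elif op == 'insert':
                noisy.append(ch)
                noisy.append(random.choice(letters))
            # 'delete' simply drops the character
        else:
            noisy.append(ch)
    return ''.join(noisy)

# add_noise('may i borrow this book')  ->  e.g. 'may i borow this bopk'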

Kindly go through the previous blog to fully understand the above piece of code. The link to the previous blog can be found here.

We create a function that reads the data, preprocesses it, and adds noise to it. Here we have taken 80,000 samples from the dataset so as not to exceed memory capacity. If you have a GPU with more memory, you could play around with the number of samples.
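Putting the pieces together, the dataset loader might look roughly like this (create_dataset is an illustrative name, built on the preprocess_sentence and add_noise sketches above):

import io

def create_dataset(path, num_examples=80000):
    # Each line of spa.txt is "English sentence<TAB>Spanish sentence";
    # we keep only the English half.
    lines = io.open(path, encoding='UTF-8').read().strip().split('\n')
    inputs, targets = [], []
    for line in lines[:num_examples]:
        english = preprocess_sentence(line.split('\t')[0])
        targets.append(english)            # clean sentence = model target
        inputs.append(add_noise(english))  # noisy sentence = model input
    return inputs, targets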

Input Token Index:
{' ': 0,
'a': 1,
'b': 2,
'c': 3,
'd': 4,
'e': 5,
'f': 6,
'g': 7,
'h': 8,
'i': 9,
'j': 10,
'k': 11,
'l': 12,
'm': 13,
'n': 14,
'o': 15,
'p': 16,
'q': 17,
'r': 18,
's': 19,
't': 20,
'u': 21,
'v': 22,
'w': 23,
'x': 24,
'y': 25,
'z': 26}

We define a function that takes in the data and tokenizes it. We perform character-level tokenization on our dataset since the number of embeddings stays constant. If we were to tokenize at the word level, the embedding layer would grow far too large (every small error in a word would create a new token). At the character level there is a fixed set of English letters (26, plus the space character), so the vocabulary never grows.
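A small sketch of such a character-level tokenizer, consistent with the token index printed above (the notebook's implementation may differ in details such as padding and start/end markers):

import tensorflow as tf

def tokenize(texts):
    # Character-level vocabulary; with only a-z and space this yields the
    # 27-entry index shown above (' ' = 0, 'a' = 1, ..., 'z' = 26).
    vocab = sorted(set(''.join(texts)))
    token_index = {ch: i for i, ch in enumerate(vocab)}
    sequences = [[token_index[ch] for ch in text] for text in texts]
    # Pad every sequence to the length of the longest one so they batch cleanly.
    tensor = tf.keras.preprocessing.sequence.pad_sequences(sequences, padding='post')
    return tensor, token_index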

We use an 80/20 train/test split. Finally, our data “input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val” is ready and can be used to train our model.
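Assuming input_tensor and target_tensor are the padded tensors produced by the tokenizer, the split can be done with scikit-learn's train_test_split, as in the TensorFlow tutorial this post follows:

from sklearn.model_selection import train_test_split

# 80/20 split of the noisy inputs and their clean targets
input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = \
    train_test_split(input_tensor, target_tensor, test_size=0.2)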

Model

Attention models, or attention mechanisms, are input-processing techniques for neural networks that allow the network to focus on specific aspects of a complex input, one at a time, until the entire dataset is categorized.

Source: Andrew Ng's course on Coursera

To learn about attention models in detail, kindly go through the resources below:
1. https://deepai.org/machine-learning-glossary-and-terms/attention-models
2. https://blog.floydhub.com/attention-mechanism/
3. https://arxiv.org/pdf/1409.0473.pdf

This tutorial uses Bahdanau attention for the encoder. Let’s decide on notation before writing the simplified form:

  • FC = Fully connected (dense) layer
  • EO = Encoder output
  • H = hidden state
  • X = input to the decoder

And the pseudo-code:

  • score = FC(tanh(FC(EO) + FC(H)))
  • attention weights = softmax(score, axis = 1). Softmax by default is applied on the last axis but here we want to apply it on the 1st axis, since the shape of score is (batch_size, max_length, hidden_size). Max_length is the length of our input. Since we are trying to assign a weight to each input, softmax should be applied on that axis.
  • context vector = sum(attention weights * EO, axis = 1). Same reason as above for choosing axis as 1.
  • embedding output = The input to the decoder X is passed through an embedding layer.
  • merged vector = concat(embedding output, context vector)
  • This merged vector is then given to the GRU

Encoder
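The encoder embeds each character ID and runs the sequence through a GRU, returning both the full output sequence (EO) and the final hidden state (H). Below is a sketch in the style of the TensorFlow NMT-with-attention tutorial this post is based on; vocab_size, embedding_dim, enc_units, and batch_sz are hyperparameters set in the notebook.

import tensorflow as tf

class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        # x: (batch_size, max_length) of character IDs
        x = self.embedding(x)
        output, state = self.gru(x, initial_state=hidden)
        return output, state          # EO and H

    def initialize_hidden_state(self):
        return tf.zeros((self.batch_sz, self.enc_units))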

Bahdanau Attention
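A sketch of the attention layer that mirrors the pseudo-code above line by line:

class BahdanauAttention(tf.keras.layers.Layer):
    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        # query = decoder hidden state H, values = encoder output EO
        query_with_time_axis = tf.expand_dims(query, 1)
        # score = FC(tanh(FC(EO) + FC(H)))
        score = self.V(tf.nn.tanh(
            self.W1(query_with_time_axis) + self.W2(values)))
        # attention weights = softmax(score, axis=1)
        attention_weights = tf.nn.softmax(score, axis=1)
        # context vector = sum(attention weights * EO, axis=1)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights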

Decoder
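The decoder attends over the encoder output, concatenates the resulting context vector with the embedded input character, and feeds the merged vector to a GRU followed by a dense layer over the vocabulary. Again a sketch in the tutorial's style, reusing the BahdanauAttention layer defined above:

class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.dec_units = dec_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(self.dec_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output):
        # context vector from the attention layer
        context_vector, attention_weights = self.attention(hidden, enc_output)
        # embedding output: x becomes (batch_size, 1, embedding_dim)
        x = self.embedding(x)
        # merged vector = concat(embedding output, context vector)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)
        # the merged vector is given to the GRU
        output, state = self.gru(x)
        output = tf.reshape(output, (-1, output.shape[2]))
        return self.fc(output), state, attention_weights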

We compile our model with the “adam” optimizer and “sparse_categorical_crossentropy” loss. The model is trained for 2 epochs with a batch_size of 64. We save a checkpoint in case the execution breaks.
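A sketch of that setup, assuming encoder and decoder have been instantiated from the classes above (the checkpoint path is illustrative):

import os
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

BATCH_SIZE = 64
EPOCHS = 2

# Save checkpoints so a broken run can be resumed.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)
# checkpoint.save(file_prefix=checkpoint_prefix) is called periodically during training.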

Training

  1. Pass the input through the encoder, which returns the encoder output and the encoder hidden state.
  2. The encoder output, encoder hidden state, and the decoder input (which is the start token) are passed to the decoder.
  3. The decoder returns the predictions and the decoder hidden state.
  4. The decoder hidden state is then passed back into the model and the predictions are used to calculate the loss.
  5. Use teacher forcing to decide the next input to the decoder.
  6. Teacher forcing is the technique where the target word is passed as the next input to the decoder.
  7. The final step is to calculate the gradients, apply them with the optimizer, and backpropagate (see the sketch below).
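Put together, a training step in the tutorial's style might look like this (START_ID stands for whatever index the notebook assigns to its start-of-sequence marker):

import tensorflow as tf

@tf.function
def train_step(inp, targ, enc_hidden):
    loss = 0
    with tf.GradientTape() as tape:
        # 1. Encoder returns its output sequence and final hidden state.
        enc_output, enc_hidden = encoder(inp, enc_hidden)
        dec_hidden = enc_hidden
        # 2. The first decoder input is the start token for every sample in the batch.
        dec_input = tf.expand_dims([START_ID] * BATCH_SIZE, 1)
        for t in range(1, targ.shape[1]):
            # 3./4. Decoder returns predictions and its new hidden state,
            #       which is fed back in on the next step.
            predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
            loss += loss_object(targ[:, t], predictions)
            # 5./6. Teacher forcing: the true target character is the next input.
            dec_input = tf.expand_dims(targ[:, t], 1)
    # 7. Compute the gradients and let the optimizer apply them.
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss / int(targ.shape[1])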

Inference

The evaluate function is similar to the training loop, except we don’t use teacher forcing here. The input to the decoder at each time step is its previous predictions along with the hidden state and the encoder output.
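A sketch of such an evaluate function, plus the autocorrect wrapper used below (token_index, index_to_char, max_length_inp, max_length_targ, units, START_ID, and EOS_ID are assumed to come from the earlier preparation steps):

import tensorflow as tf

def evaluate(sentence):
    sentence = preprocess_sentence(sentence)
    inputs = [token_index[ch] for ch in sentence]
    inputs = tf.keras.preprocessing.sequence.pad_sequences(
        [inputs], maxlen=max_length_inp, padding='post')
    inputs = tf.convert_to_tensor(inputs)

    result = ''
    hidden = tf.zeros((1, units))
    enc_output, enc_hidden = encoder(inputs, hidden)
    dec_hidden = enc_hidden
    dec_input = tf.expand_dims([START_ID], 0)

    for t in range(max_length_targ):
        predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
        predicted_id = int(tf.argmax(predictions[0]))
        if predicted_id == EOS_ID:      # stop when the end-of-sequence token appears
            break
        result += index_to_char[predicted_id]
        # No teacher forcing: the prediction is fed back as the next input.
        dec_input = tf.expand_dims([predicted_id], 0)

    return result

def autocorrect(sentence):
    print(evaluate(sentence))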

Here the model keeps predicting the next character until the EOS token is produced or “max_length_targ” is exceeded. To perform inference on our text, we can run the following snippet.

autocorrect(u'im jfst wacching teaevision')
# im just watching television

Download the whole code or run the code on Google Colab using this link:
https://github.com/piyush0511/SpellChecker-AutoCorrect/blob/main/Autocorrect-AttentionModel.ipynb

We can see here that our model handles variable-length input and variable-length output, something few model architectures do out of the box. There are a number of further improvements that could be made to push the performance even higher.

Connect with me on LinkedIn

References

  1. https://arxiv.org/pdf/1409.0473.pdf
  2. https://towardsdatascience.com/attention-networks-c735befb5e9f
  3. https://www.tensorflow.org/text/tutorials/nmt_with_attention
  4. https://towardsdatascience.com/deep-learning-autocorrect-product-and-technical-overview-1c219cee0698

Thank you for reading,
Piyush
