Sequence-to-Sequence with Attention in TensorFlow¶

A tutorial on using attention in a sequence-to-sequence network.

by Incomplete

Based on Effective Approaches to Attention-based Neural Machine Translation by Minh-Thang Luong, Hieu Pham, and Christopher D. Manning, and on the TensorFlow source code.

This tutorial is divided into two parts:

  • The concept, the data flow and the math
  • A concrete model that translates number pronunciations into digits, e.g. "one hundred and twenty seven" -> "127"

Summary of Notations¶

This is a reference for the notations used; most of them are the same as in the original paper.

Notation Meaning
$h_t$ target hidden state (i.e. output) at target time $t$
$\bar{h}_s$ source hidden state at source time $s$
$a_t$ alignment vector at target time $t$
$a_{t_{i}}$ $i$-th element of $a_t$
$c_t$ context vector at target time $t$
$\tilde{h}_t$ attentional hidden state at target time $t$
$K$ number of encoding time steps, a.k.a. max_time
$S_t$ a subset of $\{1, \dots, K\}$ at target time $t$
$S'_t$ the encoder hidden states of interest, i.e. $S'_t = \{\bar{h}_i : i \in S_t\}$

What Is Attention?¶

In a sequence-to-sequence architecture, when decoding at time $t$, instead of considering only the current input and the previous state, we can derive what is called a context vector $c_t$ from the encoder's outputs at some or all of the encoding time steps. This context vector, combined with the current decoder hidden state $h_t$, produces what is called an attentional hidden state $\tilde{h}_t$, which can then be used to derive the prediction at $t$ or be fed to the next time step. This process of consulting the encoder's outputs while decoding is called attention.
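
To make this concrete, below is a minimal NumPy sketch of a single decoding step with global attention and the dot score (both defined in the steps that follow); the names, shapes, and random weights are purely illustrative and are not the notebook's actual code.

import numpy as np

hidden, K = 4, 7                                       # illustrative sizes; K is max_time
encoder_outputs = np.random.randn(K, hidden)           # source hidden states h_bar_1 .. h_bar_K
h_t = np.random.randn(hidden)                          # decoder hidden state at target time t
W_c = np.random.randn(hidden, 2 * hidden)              # trainable in a real model
W_o = np.random.randn(10, hidden)                      # projection to (say) 10 vocab logits

scores = encoder_outputs @ h_t                         # dot score for every source position
a_t = np.exp(scores) / np.sum(np.exp(scores))          # alignment vector (softmax over scores)
c_t = a_t @ encoder_outputs                            # context vector
h_tilde_t = np.tanh(W_c @ np.concatenate([c_t, h_t]))  # attentional hidden state
logits_t = W_o @ h_tilde_t                             # optional: logits for a prediction at t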

Below are the details of the data flow of this architecture and its implementation in TensorFlow.

Overview of the Architecture¶

The chart below is an overview of the architecture (seq2seq with attention) and its data flow.

Read this chart bottom-to-top, starting at $\cdots, \bar{h}_{s-1}, \bar{h}_s, \bar{h}_{s+1}, \cdots$ and $h_t$.

tf.contrib.seq2seq.LuongAttention is used as an example.

Data flow of attention-based seq2seq network

Step 1: Select $S_t$¶

$S'_t = \{\bar{h}_i : i \in S_t\}$ is the set of encoder outputs that we are "paying attention" to, i.e., in a later step, we will derive some other information from this set.

The simplest choice is to let $S'_t = \{\bar{h}_i : 1 \leq i \leq K\}$ for all $t$; this is referred to as global attention.

Another choice is to pick a position $p_t$ with $1 \leq p_t \leq K$ and let $$S_t = \{i : 1 \leq i \leq K,\ p_t - D \leq i \leq p_t + D\}$$ where $D$ is the window size and is chosen empirically.

If the source sequence is roughly the same length as the target sequence, we can simply let $p_t = t$. This is referred to as local-m attention ("m" stands for "monotonic").
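
For instance, continuing the NumPy sketch above with illustrative values for $D$ and $t$:

D, t = 2, 4                                            # window size and target time, illustrative
p_t = t                                                # local-m: center the window at t
S_t = [i for i in range(1, K + 1) if p_t - D <= i <= p_t + D]
# with K == 7 this gives S_t == [2, 3, 4, 5, 6]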

We can also let

$$p_t = K \cdot \text{sigmoid}(v_p^\top \text{tanh}(W_p h_t))$$

where $v_p$ and $W_p$ are trainable model parameters. This is referred to as local-p attention ("p" stands for "predictive").
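
Continuing the sketch, a minimal version of the predicted position; $W_p$ and $v_p$ are random here but would be trained in a real model:

W_p = np.random.randn(hidden, hidden)
v_p = np.random.randn(hidden)

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

p_t = K * sigmoid(v_p @ np.tanh(W_p @ h_t))            # a real number in (0, K)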

Step 2: Create Alignment Vector $a_t$¶

The so-called alignment vector $a_t$ will be used as the weights of a weighted sum in a later step.

For $s \in S_t$, let $$\text{align}(h_t, \bar{h}_s) = \frac{\text{exp}(\text{score}(h_t, \bar{h}_s))}{\sum_{s' \in S_t}\text{exp}(\text{score}(h_t, \bar{h}_{s'}))}$$

The $\text{score}$ function can be chosen based on the problem at hand; these three are provided in the paper:

dot: $\text{score}(h_t, \bar{h}_s) = h_t^\top \bar{h}_s$

general: $\text{score}(h_t, \bar{h}_s) = h_t^\top W_a \bar{h}_s$

concat: $\text{score}(h_t, \bar{h}_s) = v_a^\top \text{tanh}(W_a [h_t; \bar{h}_s])$
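
The dot score is the one used in the earlier NumPy sketch; the other two only differ in their trainable parameters. Continuing the sketch (shapes illustrative):

W_a = np.random.randn(hidden, hidden)                  # for general
W_a_cat = np.random.randn(hidden, 2 * hidden)          # for concat
v_a = np.random.randn(hidden)

def score_dot(h_t, h_bar_s):
  return h_t @ h_bar_s

def score_general(h_t, h_bar_s):
  return h_t @ W_a @ h_bar_s

def score_concat(h_t, h_bar_s):
  return v_a @ np.tanh(W_a_cat @ np.concatenate([h_t, h_bar_s]))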

If global or local-m attention is chosen, we simply let

$$a_{t_s} = \text{align}(h_t, \bar{h}_s) \quad \text{for}\ s \in S_t$$

If local-p attention is chosen, to favor points near $p_t$, we let

$$a_{t_s} = \text{align}(h_t, \bar{h}_s) \text{exp}(-\frac{(s-p_t)^2}{2\sigma^2}) \quad \text{for}\ s \in S_t$$

where $\sigma$ is a hyperparameter, empirically chosen to be $\frac{D}{2}$.
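
Continuing the sketch, a minimal version of the local-p weighting, reusing the illustrative p_t, D and score_dot from above:

sigma = D / 2
window = np.array([i for i in range(1, K + 1) if p_t - D <= i <= p_t + D])   # S_t
scores_w = np.array([score_dot(h_t, encoder_outputs[i - 1]) for i in window])
align_w = np.exp(scores_w) / np.sum(np.exp(scores_w))                        # align(h_t, h_bar_s)
a_t_local = align_w * np.exp(-((window - p_t) ** 2) / (2 * sigma ** 2))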

Step 3: Use $S'_t$ and $a_t$ to Create the Context Vector $c_t$¶

The context vector is a weighted sum of the source hidden states of interest:

$$c_t = \sum_{s \in S_t}a_{t_s}\bar{h}_s$$
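
In batched code, such as the AttentionWrapper used later, this weighted sum becomes a matrix multiplication; a NumPy sketch with illustrative values, assuming global attention:

batch_size = 3
memory = np.random.randn(batch_size, K, hidden)        # encoder outputs, batch-major
alignments = np.random.rand(batch_size, K)
alignments /= alignments.sum(axis=1, keepdims=True)    # each row sums to 1, like a_t

# (batch, 1, K) @ (batch, K, hidden) -> (batch, 1, hidden) -> (batch, hidden)
context = np.squeeze(np.expand_dims(alignments, 1) @ memory, axis=1)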

Step 4: Create Attentional Hidden State $\tilde{h}_t$¶

The current context vector $c_t$ and the current decoder hidden state $h_t$ are used to create the so-called attentional hidden state $\tilde{h}_t$:

$$\tilde{h}_t = \text{tanh}(W_c[c_t; h_t])$$

Step 5: Optionally, Make A Prediction¶

For example, we can sample from the distribution $\text{softmax}(W_o\tilde{h}_t)$ to make a prediction at target time $t$, where $W_o$ is a trainable parameter.

Step 6: Optionally, Input Feeding¶

$\tilde{h}_t$ can be concatenated with the input of the decoder at target time $t+1$ to produce a new version of the input.
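
A minimal sketch of input feeding, matching the default cell_input_fn of the AttentionWrapper used later (shapes illustrative):

embedding_dim = 5
next_input_embedded = np.random.randn(embedding_dim)   # embedded decoder input at time t + 1
next_cell_input = np.concatenate([next_input_embedded, h_tilde_t])   # what the decoder cell sees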

An Example: Translating "one hundred and twenty seven" to "127"¶

This example model translates English number words to digit strings. For example, given the input string:

"one hundred and twenty three thousand, four hundred and fifty six",

the model should output the string:

"123456".

It does so by utilizing a sequence-to-sequence network with attention.

The code can be roughly split into these parts:

  • The boring data preparation
  • The encoder
  • The decoder, where the attention happens; see the function wrap_with_attention
  • Building the graph
  • Training
  • Test
  • Visualization of the alignments
In [1]:
import logging
import time
import itertools
import os
import string

import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.info(tf.VERSION)
/usr/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
INFO:__main__:1.8.0
In [2]:
# Named tuples will be used to hold tf variable names and hyperparameters,
# so that no raw strings are passed around.
def make_namedtuple(name, field_names, field_values):
  import collections
  t = collections.namedtuple(name, field_names)
  for fname, fvalue in zip(field_names, field_values):
    setattr(t, fname, fvalue)
  return t
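
For example (illustrative): the fields are set as attributes on the returned namedtuple class itself, and the class, not an instance, is returned:

Point = make_namedtuple('Point', ['x', 'y'], [3, 4])
assert Point.x == 3 and Point.y == 4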
In [3]:
# Vocabulary information.

def make_lookup(vocabs):

  vocab_to_id = {c: i for i, c in enumerate(vocabs)}
  id_to_vocab = {i: c for c, i in vocab_to_id.items()}
  lookup = make_namedtuple(
    'VocabLookup',
    field_names=['vocab_to_id', 'id_to_vocab'],
    field_values=[vocab_to_id, id_to_vocab])
  return lookup

support_chars = ['<PAD>', '<SOS>', '<EOS>']

# 'three thousand, two hundred and seven' -> '3207'
source_vocabs = [x for x in string.ascii_lowercase] + [' ', ','] + support_chars
target_vocabs = [x for x in string.digits] + support_chars

source_lookup = make_lookup(source_vocabs)
target_lookup = make_lookup(target_vocabs)
In [4]:
def int_to_word(n):
  # Convert integer to its English form, where 0 <= n < 1e9
  # e.g. int_to_word(27) == 'twenty seven'
  lookup = {
    0: 'zero', 1: 'one', 2: 'two', 3: 'three', 4: 'four', 5: 'five',
    6: 'six', 7: 'seven', 8: 'eight', 9: 'nine', 10: 'ten',
    11: 'eleven', 12: 'twelve', 13: 'thirteen', 14: 'fourteen',
    15: 'fifteen', 16: 'sixteen', 17: 'seventeen', 18: 'eighteen',
    19: 'nineteen', 20: 'twenty', 30: 'thirty', 40: 'forty',
    50: 'fifty', 60: 'sixty', 70: 'seventy', 80: 'eighty', 90: 'ninety'}

  def f(rest, s):
    if rest <= 19:
      return s + [lookup[rest]]
    elif 19 < rest <= 99:
      if rest % 10 == 0:
        return s + [lookup[rest]]
      else:
        ty = lookup[(rest // 10) * 10]
        finger = lookup[rest % 10]
        return s + [ty, finger]
    elif 99 < rest <= 999:
      q = rest // 100
      r = rest % 100
      if r == 0:
        return s + [lookup[q], 'hundred']
      else:
        return s + [lookup[q], 'hundred', 'and'] + f(r, [])
    else:
      if rest >= int(1e9):
        raise ValueError(n)
      elif rest >= int(1e6):
        q = rest // int(1e6)
        r = rest % int(1e6)
        if r == 0:
          return s + f(q, []) + ['million']
        else:
          return s + f(q, []) + ['million,'] + f(r, [])
      elif int(1e3) <= rest < int(1e6):
        q = rest // int(1e3)
        r = rest % int(1e3)
        if r == 0:
          return s + f(q, []) + ['thousand']
        else:
          return s + f(q, []) + ['thousand,'] + f(r, [])
      else:
        raise ValueError(n)

  return ' '.join(f(n, []))
In [5]:
test_ints = [
  0, 1, 7, 17, 20, 30, 45, 99,
  100, 101, 123, 999,
  1000, 1001, 1234, 9999,
  123456, 234578, 999999,
  1000000, 1000001, 9123456,
  999123456, 123000789]
In [6]:
for n in test_ints:
  print(f'{n:>11,} {int_to_word(n)}')
          0 zero
          1 one
          7 seven
         17 seventeen
         20 twenty
         30 thirty
         45 forty five
         99 ninety nine
        100 one hundred
        101 one hundred and one
        123 one hundred and twenty three
        999 nine hundred and ninety nine
      1,000 one thousand
      1,001 one thousand, one
      1,234 one thousand, two hundred and thirty four
      9,999 nine thousand, nine hundred and ninety nine
    123,456 one hundred and twenty three thousand, four hundred and fifty six
    234,578 two hundred and thirty four thousand, five hundred and seventy eight
    999,999 nine hundred and ninety nine thousand, nine hundred and ninety nine
  1,000,000 one million
  1,000,001 one million, one
  9,123,456 nine million, one hundred and twenty three thousand, four hundred and fifty six
999,123,456 nine hundred and ninety nine million, one hundred and twenty three thousand, four hundred and fifty six
123,000,789 one hundred and twenty three million, seven hundred and eighty nine
In [7]:
# Make batches.
def pad_right(lol, vocab_to_id):
  # Pad right with <PAD>, such that all rows have the same length.
  row_lengths = [len(row) for row in lol]
  max_length = np.max(row_lengths)
  arr = np.ndarray((len(lol), max_length))
  arr.fill(vocab_to_id['<PAD>'])
  for i, length, row in zip(range(len(lol)), row_lengths, lol):
    arr[i, :length] = row
  return arr, row_lengths


def push_sos(arr, vocab_to_id):
  # Prepend each row with <SOS>, and drop the last column,
  # such that the length is unchanged.
  soses = np.ndarray((arr.shape[0], 1))
  soses.fill(vocab_to_id['<SOS>'])
  with_sos = np.concatenate([soses, arr], axis=1)
  no_last = with_sos[:, :-1]
  return no_last


def make_batch(batch_size, lower, upper, source_lookup, target_lookup):
  # Make one batch.
  target_numbers = np.random.randint(low=lower, high=upper, size=(batch_size,))
  target_strings = [str(n) for n in target_numbers]
  source_strings = [int_to_word(n) for n in target_numbers]

  source_ids = [[source_lookup.vocab_to_id[vocab] for vocab in seq]
                for seq in source_strings]
  target_ids = [[target_lookup.vocab_to_id[vocab] for vocab in seq] + [target_lookup.vocab_to_id['<EOS>']]
                for seq in target_strings]

  padded_source_ids, source_lengths = pad_right(source_ids, source_lookup.vocab_to_id)
  padded_target_ids, target_lengths = pad_right(target_ids, target_lookup.vocab_to_id)
  decoder_input_ids = push_sos(padded_target_ids, target_lookup.vocab_to_id)
  batch = (padded_source_ids, source_lengths,
           padded_target_ids, target_lengths,
           decoder_input_ids)
  return batch, set(target_numbers)


def make_batches(hparams, num_batches, source_lookup, target_lookup):
  batches, seens = [], set()
  for _ in range(num_batches):
    batch, seen = make_batch(
      hparams.batch_size, hparams.data_lower, hparams.data_upper,
      source_lookup=source_lookup,
      target_lookup=target_lookup)
    batches.append(batch)
    seens = seens | seen
  return batches, seens


# An iterator that produces feed dicts.
def get_feed(all_batches, graph, handles):
  for (padded_source_ids, source_lengths,
       padded_target_ids, target_lengths,
       decoder_input_ids
       ) in all_batches:
    yield {graph.get_tensor_by_name(f'{handles.batch_size}:0'): len(source_lengths),
           graph.get_tensor_by_name(f'{handles.input_ids}:0'): padded_source_ids,
           graph.get_tensor_by_name(f'{handles.input_lengths}:0'): source_lengths,
           graph.get_tensor_by_name(f'{handles.decoder_input_ids}:0'): decoder_input_ids,
           graph.get_tensor_by_name(f'{handles.decoder_input_lengths}:0'): target_lengths,
           graph.get_tensor_by_name(f'{handles.target_ids}:0'): padded_target_ids,
           graph.get_tensor_by_name(f'{handles.target_lengths}:0'): target_lengths,
           graph.get_tensor_by_name(f'{handles.max_decode_iterations}:0'): 2 * np.max(source_lengths)}
In [8]:
handle_fields = [
  'input_ids',
  'input_lengths',

  'decoder_input_ids',
  'decoder_input_lengths',
  'max_decode_iterations',

  'target_ids',
  'target_lengths',

  'train_loss',
  'optimize',
  'infer_logits',

  'train_alignments',
  'train_alignments_summary',
  'attention_cell',

  'batch_size',
  'train_loss_summary',
  'val_loss_summary']

# Various tf variable names.
Handles = make_namedtuple('Handles', handle_fields, handle_fields)
In [9]:
# An encoder is just an RNN: we feed it the input sequence,
# use its (slightly transformed) final state as the initial state of the decoder,
# and use its outputs when doing attention.

def build_encoder(source_vocab_size, source_embedding_dim,
                  rnn_size, rnn_layers,
                  batch_size, handles):

  # Input to the encoder.
  # Each entry is a vocabulary id,
  # and every sequence (i.e. row) is padded to have the same length.
  # (batch_size, sequence_length)
  input_ids = tf.placeholder(tf.int32, (None, None), name=handles.input_ids)

  # Length of each sequence, without counting the padding,
  # This is also the length of the outputs.
  # (batch_size,)
  sequence_lengths = tf.placeholder(tf.int32, (None,), name=handles.input_lengths)

  # Embedded version of input_ids.
  # (batch_size, max_time, source_embedding_dim)
  # where max_time == tf.shape(input_ids)[1] == tf.reduce_max(sequence_lengths)
  inputs_embedded = contrib.layers.embed_sequence(
    ids=input_ids,
    vocab_size=source_vocab_size,
    embed_dim=source_embedding_dim)

  def build_cell():
    # The cell advances one time step in one layer.
    cell = tf.nn.rnn_cell.LSTMCell(
      num_units=rnn_size,
      initializer=tf.random_uniform_initializer(-0.1, 0.1))
    return cell

  # Conceptually:
  #
  # multi_rnn_cell(input,  [layer1_state,     layer2_state,     ...])
  # ->            (output, [new_layer1_state, new_layer2_state, ...])
  #
  # i.e. advance one time step in each layer.
  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(
    [build_cell() for _ in range(rnn_layers)])

  zero_state = multi_rnn_cell.zero_state(batch_size, tf.float32)

  # Advance sequence_lengths[i] time steps for batch element i, for all i.
  # outputs :: [batch_size, max_time, rnn_size]
  # final_state :: (rnn_layers, 2=lstm_state_tuple_size, batch_size, rnn_size)
  outputs, final_state = tf.nn.dynamic_rnn(
    cell=multi_rnn_cell,
    inputs=inputs_embedded,
    sequence_length=sequence_lengths,
    initial_state=zero_state)

  return sequence_lengths, (outputs, final_state)
In [10]:
def wrap_with_attention(cell,
                        num_units_memory, attention_size,
                        encoder_outputs, encoder_output_lengths,
                        attention_cell_name):
  # Wrap cell with LuongAttention.
  #
  # If you simplify enough, an AttentionMechanism is a function with the signature:
  # memory -> query -> alignments
  #
  # memory is usually the encoder's hidden states;
  # it is transformed with a dense layer that has num_units outputs,
  # and the result of this transformation is called keys in TensorFlow's source code.
  #
  # query is usually the decoder's hidden state at the current time step.
  #
  # Unless you manage the memory and/or transform the query yourself,
  # num_units must be equal to the decoder's hidden state size,
  # because internally the dot score (mentioned in Step 2) is used to
  # combine the keys and the decoder hidden state.
  #
  # The score calculated with dot is passed to probability_fn,
  # which by default is softmax, turning it into a probability distribution.
  #
  # In the default setting, when the entire encoder output is used as memory,
  # LuongAttention (together with the AttentionWrapper below) implements
  # global attention with input feeding, and the dot product is used as the score function.
  #
  attention_mechanism = contrib.seq2seq.LuongAttention(

    num_units=num_units_memory,

    memory=encoder_outputs,

    # used to mask padding positions
    memory_sequence_length=encoder_output_lengths,

    # convert score to probabilities, default is tf.nn.softmax
    probability_fn=tf.nn.softmax,

    # If memory_sequence_length is not None,
    # then a mask is created from it, and the score is transformed using this mask:
    # true (non-padding) positions keep the original score,
    # false (padding) positions get score_mask_value.
    # The default is -inf, which, when combined with the softmax probability_fn,
    # gives padding positions near-zero probabilities.
    # We choose to use the default.
    score_mask_value=None)

  # AttentionWrapper wraps an RNNCell to get another RNNCell,
  # handling attention along the way.
  #
  # data flow:
  #
  # cell_inputs = cell_input_fn(inputs, attentional_hidden_state_at_previous_time)
  # cell_output, next_cell_state = cell(cell_inputs, cell_state)
  # alignments, _unused = attention_mechanism(cell_output, _unused)
  # context = matmul(alignments, masked(attention_mechanism.memory))
  # attention =
  #   if attention_layer
  #   then attention_layer(concat([cell_output, context], 1))
  #   else context
  # output =
  #   if output_attention
  #   then attention
  #   else cell_output
  #
  new_cell = contrib.seq2seq.AttentionWrapper(
    # the original cell to wrap
    cell=cell,

    attention_mechanism=attention_mechanism,

    # size of the attentional hidden state
    attention_layer_size=attention_size,

    # can be used to enable input feeding
    # the default is: lambda inputs, attention: array_ops.concat([inputs, attention], -1)
    # which is input feeding
    cell_input_fn=None,

    # store all alignment history for visualization purpose
    alignment_history=True,

    # output the original cell's output (False),
    # or output the attentional hidden state (True)
    output_attention=True,

    name=attention_cell_name)

  return new_cell
In [11]:
# The decoder is an RNN with LuongAttention.
#
# The final state of the encoder is wrapped and passed as the initial state of the decoder,
# and a start-of-sequence (SOS) symbol (its embedding vector, to be precise)
# is used as the input at time step 0.
# (The attentional hidden state is kept inside the AttentionWrapper's state,
# so we don't have to pass the attentional hidden state manually at time 0,
# i.e. the cell after wrapping accepts the same inputs as the cell before wrapping.)
#
# The output at each time step is projected (using a dense layer)
# to target_vocab_size logits.
# When making inferences, we sample from the output at time t
# and use it as the input at time t+1.
#
# When training, instead of feeding the model's own previous output back in,
# we feed it the ground truth corresponding to each time step (teacher forcing).
def build_decoder(encoder_state,
                  encoder_outputs, encoder_output_lengths,
                  target_vocab_size, target_embedding_dim,
                  rnn_size, rnn_layers,
                  attention_size,
                  batch_size, max_decode_iterations, target_vocab_to_id, handles):

  # Input sequence to the decoder; this is the ground truth and is only used during training.
  # The first element of each sequence is always the SOS symbol.
  # (batch_size, sequence_length)
  input_ids = tf.placeholder(tf.int32, (None, None), name=handles.decoder_input_ids)

  # Length of input_ids, without counting padding.
  # (batch_size,)
  sequence_lengths = tf.placeholder(tf.int32, (None,), name=handles.decoder_input_lengths)

  input_embeddings = tf.Variable(tf.random_uniform((target_vocab_size, target_embedding_dim)))

  inputs_embedded = tf.nn.embedding_lookup(
    params=input_embeddings,
    ids=input_ids)

  def build_cell():
    cell = tf.nn.rnn_cell.LSTMCell(
      num_units=rnn_size,
      initializer=tf.random_uniform_initializer(-0.1, 0.1))
    return cell

  multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(
    [build_cell() for _ in range(rnn_layers)])

  # Notice that we wrap with attention after layering,
  # which is consistent with the flow chart above.
  cell_with_attention = wrap_with_attention(
    cell=multi_rnn_cell,
    num_units_memory=rnn_size,
    attention_size=attention_size,
    encoder_outputs=encoder_outputs,
    encoder_output_lengths=encoder_output_lengths,
    attention_cell_name=handles.attention_cell)

  # Transform the RNN output to logits, so that we can sample from it.
  projection_layer = tf.layers.Dense(
    units=target_vocab_size,
    activation=None,
    use_bias=False)

  def make_logits(helper, reuse):
    with tf.variable_scope('decoder', reuse=reuse):
      initial_state = cell_with_attention.zero_state(batch_size=batch_size, dtype=tf.float32)
      initial_state = initial_state.clone(cell_state=encoder_state)
      decoder = contrib.seq2seq.BasicDecoder(
        cell=cell_with_attention,
        helper=helper,
        initial_state=initial_state,
        output_layer=projection_layer)

      final_outputs, final_state, _final_sequence_length = contrib.seq2seq.dynamic_decode(
        decoder=decoder,
        maximum_iterations=max_decode_iterations,
        impute_finished=True)

      # rnn_output :: (batch_size, max_time, vocab_size)
      return final_outputs.rnn_output, final_state

  # The Helper is used to sample from the logits;
  # you can swap one Helper for another to get different sampling behaviour.
  #
  # At time t, a TrainingHelper just reads data from inputs[:, t]
  # and uses it as the input at t.
  train_helper = contrib.seq2seq.TrainingHelper(
    inputs=inputs_embedded,
    sequence_length=sequence_lengths)

  # Greedy, as in GreedyEmbeddingHelper, means that an argmax is taken.
  # So at time t, the output at t-1, which has shape (vocab_size,)
  # (considering only one sample), is sampled from greedily and used as the input at time t,
  # i.e. input at t = lookup_embedding(embedding, argmax(output at t-1))
  infer_helper = contrib.seq2seq.GreedyEmbeddingHelper(
    embedding=input_embeddings,
    start_tokens=tf.tile(tf.constant([target_vocab_to_id['<SOS>']]), [batch_size]),
    end_token=target_vocab_to_id['<EOS>'])

  train_logits, train_final_state = make_logits(train_helper, reuse=False)

  # For visualization
  alignments = train_final_state.alignment_history.stack()
  alignments = tf.transpose(alignments, perm=[1, 0, 2], name=handles.train_alignments)

  # Use the same variables used in training.
  infer_logits, _val_final_state = make_logits(infer_helper, reuse=True)
  tf.identity(infer_logits, name=handles.infer_logits)

  return train_logits, alignments
In [12]:
# Create the graph.
def build_graph(handles, hparams, target_lookup):
  graph = tf.Graph()
  with graph.as_default():
    batch_size = tf.placeholder(tf.int32, (), name=handles.batch_size)

    input_sequence_lengths, (encoder_outputs, encoder_final_state) = build_encoder(
      source_vocab_size=hparams.source_vocab_size,
      source_embedding_dim=hparams.source_embedding_dim,
      rnn_size=hparams.rnn_size,
      rnn_layers=hparams.rnn_layers,
      batch_size=batch_size,
      handles=handles)

    max_decode_iterations = tf.placeholder(tf.int32, (), name=handles.max_decode_iterations)

    train_logits, train_alignments = build_decoder(
      encoder_state=encoder_final_state,
      encoder_outputs=encoder_outputs,
      encoder_output_lengths=input_sequence_lengths,
      target_vocab_size=hparams.target_vocab_size,
      target_embedding_dim=hparams.target_embedding_dim,
      rnn_size=hparams.rnn_size,
      rnn_layers=hparams.rnn_layers,
      attention_size=hparams.attention_size,
      batch_size=batch_size,
      max_decode_iterations=max_decode_iterations,
      target_vocab_to_id=target_lookup.vocab_to_id,
      handles=handles)

    # Labels; each target sequence ends with an EOS token.
    # Used only during training.
    # (batch_size, target_sequence_length)
    target_ids = tf.placeholder(tf.int32, (None, None), name=handles.target_ids)

    # Length of target_ids, without counting padding.
    # (batch_size,)
    target_lengths = tf.placeholder(tf.int32, (None,), name=handles.target_lengths)

    # Since our target_ids are effectively of variable length,
    # we mask out the padding positions.
    loss_mask = tf.sequence_mask(lengths=target_lengths, dtype=tf.float32)
    train_loss_ = contrib.seq2seq.sequence_loss(
      logits=train_logits,
      targets=target_ids,
      weights=loss_mask)
    train_loss = tf.identity(train_loss_, name=handles.train_loss)

    tf.summary.scalar(handles.train_loss_summary, train_loss)
    tf.summary.scalar(handles.val_loss_summary, train_loss)

    # The resulting image should have a white main diagonal,
    # because of the relationship between the source and target sequences.
    # Each row in the image corresponds to one target time step,
    # and thus has a width of max source time.
    tf.summary.image(
      name=handles.train_alignments_summary,
      tensor=tf.cast(tf.expand_dims(train_alignments * 255, axis=3), dtype=tf.uint8),
      max_outputs=10)

    optimizer = tf.train.AdamOptimizer(hparams.lr)
    gradients = optimizer.compute_gradients(train_loss)
    clipped_gradients = [(tf.clip_by_value(grad, -5., 5.), var)
                         for grad, var in gradients if grad is not None]
    optimizer.apply_gradients(clipped_gradients, name=handles.optimize)

  return graph
In [13]:
def restore_model(sess, checkpoint_prefix):
  loader = tf.train.import_meta_graph(checkpoint_prefix + '.meta')
  loader.restore(sess, checkpoint_prefix)

def save_model(sess, checkpoint_prefix):
  saver = tf.train.Saver()
  saver.save(sess, checkpoint_prefix)
  logger.info(f'model saved to {checkpoint_prefix}')
In [14]:
def train(sess, handles, hparams, train_batches, val_batches):
  run_id = time.strftime("%Z%Y-%m%d-%H%M%S", time.gmtime())
  logger.info(('run_id', run_id))

  with tf.summary.FileWriter(
    logdir=hparams.tensorboard_dir(run_id),
    graph=sess.graph
  ) as summary_writer:

    train_loss = sess.graph.get_tensor_by_name(f'{handles.train_loss}:0')
    optimize = sess.graph.get_operation_by_name(handles.optimize)

    train_loss_summary = sess.graph.get_tensor_by_name(f'{handles.train_loss_summary}:0')
    val_loss_summary = sess.graph.get_tensor_by_name(f'{handles.val_loss_summary}:0')

    train_alignments_summary = sess.graph.get_tensor_by_name(f'{handles.train_alignments_summary}:0')

    global_step = 0
    for i_epoch in range(1, hparams.num_epochs + 1):
      time_begin = time.monotonic()
      train_loss_vals = []
      for feed in get_feed(all_batches=train_batches, graph=sess.graph, handles=handles):
        global_step += 1

        (train_loss_val, _optimize_val, summary_val, train_alignments_summary_val
         ) = sess.run([
          train_loss, optimize, train_loss_summary, train_alignments_summary
          ], feed)

        summary_writer.add_summary(summary_val, global_step=global_step)
        summary_writer.add_summary(train_alignments_summary_val, global_step=global_step)
        train_loss_vals.append(train_loss_val)

      val_loss_vals = []
      for feed in get_feed(all_batches=val_batches, graph=sess.graph, handles=handles):
        val_loss_val, summary_val = sess.run([train_loss, val_loss_summary], feed)
        summary_writer.add_summary(summary_val, global_step=global_step)
        val_loss_vals.append(val_loss_val)

      train_loss_val = np.mean(train_loss_vals[-len(val_loss_vals):])
      val_loss_val = np.mean(val_loss_vals)

      time_end = time.monotonic()
      logger.info(' '.join([
        f'epoch={i_epoch:0{len(str(hparams.num_epochs))}d}/{hparams.num_epochs}',
        f'train_loss={train_loss_val:.4f}',
        f'val_loss={val_loss_val:.4f}',
        f'duration={time_end-time_begin:.4f}s']))

  save_model(sess, hparams.checkpoint_prefix)
In [15]:
Hparams = make_namedtuple('Hparams', *zip(*[
  # rnn state size and number of layers
  # since we are passing the final encoder state as the initial decoder state,
  # this is shared by both the encoder and the decoder
  # another thought is to use a dense layer to bridge the encoder state and decoder state
  ('rnn_size', 30),
  ('rnn_layers', 2),

  ('attention_size', 60),

  ('source_vocab_size', len(source_vocabs)),
  ('target_vocab_size', len(target_vocabs)),

  ('source_embedding_dim', int(len(source_vocabs) / 1.5)),
  ('target_embedding_dim', int(len(target_vocabs) / 1.5)),

  # ('batch_size', 3),
  # ('num_epochs', 2),
  # ('num_train_batches', 4),
  # ('num_val_batches', 4),

  ('batch_size', 128),
  ('num_epochs', 30),
  ('num_train_batches', 100),
  ('num_val_batches', 10),

  ('lr', 0.001),
  ('data_lower', 0),
  ('data_upper', 20 * 128 * (100 + 10)),

  ('checkpoint_prefix', 'checkpoints/attention_word_to_number/model'),
  ('tensorboard_dir', lambda run_id: f'tensorboard/attention_word_to_number/{run_id}')
]))
os.makedirs(os.path.dirname(Hparams.checkpoint_prefix), exist_ok=True)
In [16]:
train_batches, train_seen = make_batches(
  Hparams, Hparams.num_train_batches,
  source_lookup=source_lookup,
  target_lookup=target_lookup)

val_batches, val_seen = make_batches(
  Hparams, Hparams.num_val_batches,
  source_lookup=source_lookup,
  target_lookup=target_lookup)
seen = train_seen | val_seen

logger.info(f'seen_data/all_data = {len(seen)}/{Hparams.data_upper - Hparams.data_lower} = {len(seen)/(Hparams.data_upper - Hparams.data_lower):.4f}')
INFO:__main__:seen_data/all_data = 13763/281600 = 0.0489
In [17]:
# Train the model from scratch.
def cold_train(handles, hparams, train_batches, val_batches):
  with tf.Session(graph=build_graph(handles, hparams, target_lookup)) as sess:
    sess.run(tf.global_variables_initializer())
    save_model(sess, Hparams.checkpoint_prefix)
    train(sess, handles, hparams, train_batches, val_batches)


cold_train(Handles, Hparams, train_batches, val_batches)
INFO:__main__:model saved to checkpoints/attention_word_to_number/model
INFO:__main__:('run_id', 'GMT2018-0518-034110')
INFO:__main__:epoch=01/30 train_loss=1.9008 val_loss=1.8892 duration=12.1588s
INFO:__main__:epoch=02/30 train_loss=1.8225 val_loss=1.8139 duration=11.2319s
INFO:__main__:epoch=03/30 train_loss=1.8000 val_loss=1.7903 duration=13.4254s
INFO:__main__:epoch=04/30 train_loss=1.6805 val_loss=1.6735 duration=11.9208s
INFO:__main__:epoch=05/30 train_loss=1.5596 val_loss=1.5503 duration=14.9080s
INFO:__main__:epoch=06/30 train_loss=1.4357 val_loss=1.4171 duration=12.4353s
INFO:__main__:epoch=07/30 train_loss=1.2636 val_loss=1.2330 duration=11.1266s
INFO:__main__:epoch=08/30 train_loss=1.0619 val_loss=1.0561 duration=11.7510s
INFO:__main__:epoch=09/30 train_loss=0.8719 val_loss=0.8432 duration=12.9790s
INFO:__main__:epoch=10/30 train_loss=0.6670 val_loss=0.6476 duration=12.9763s
INFO:__main__:epoch=11/30 train_loss=0.4411 val_loss=0.4329 duration=15.3058s
INFO:__main__:epoch=12/30 train_loss=0.2618 val_loss=0.2851 duration=10.9570s
INFO:__main__:epoch=13/30 train_loss=0.1619 val_loss=0.1557 duration=12.4879s
INFO:__main__:epoch=14/30 train_loss=0.1110 val_loss=0.1042 duration=11.8751s
INFO:__main__:epoch=15/30 train_loss=0.0673 val_loss=0.0696 duration=11.6231s
INFO:__main__:epoch=16/30 train_loss=0.0461 val_loss=0.0486 duration=14.5002s
INFO:__main__:epoch=17/30 train_loss=0.0518 val_loss=0.0990 duration=20.0289s
INFO:__main__:epoch=18/30 train_loss=0.0531 val_loss=0.0513 duration=18.9734s
INFO:__main__:epoch=19/30 train_loss=0.0289 val_loss=0.0301 duration=11.5087s
INFO:__main__:epoch=20/30 train_loss=0.0207 val_loss=0.0223 duration=12.8686s
INFO:__main__:epoch=21/30 train_loss=0.0162 val_loss=0.0181 duration=11.4668s
INFO:__main__:epoch=22/30 train_loss=0.0130 val_loss=0.0152 duration=13.7870s
INFO:__main__:epoch=23/30 train_loss=0.0108 val_loss=0.0130 duration=22.1264s
INFO:__main__:epoch=24/30 train_loss=0.0092 val_loss=0.0113 duration=21.1544s
INFO:__main__:epoch=25/30 train_loss=0.0079 val_loss=0.0100 duration=21.9507s
INFO:__main__:epoch=26/30 train_loss=0.0068 val_loss=0.0090 duration=19.0388s
INFO:__main__:epoch=27/30 train_loss=0.0059 val_loss=0.0082 duration=19.1095s
INFO:__main__:epoch=28/30 train_loss=0.0052 val_loss=0.0075 duration=19.1081s
INFO:__main__:epoch=29/30 train_loss=0.0045 val_loss=0.0069 duration=18.9062s
INFO:__main__:epoch=30/30 train_loss=0.0040 val_loss=0.0065 duration=18.9704s
INFO:__main__:model saved to checkpoints/attention_word_to_number/model
In [ ]:
# If the loss is still too big but still decreasing,
# we can load the trained model and continue training.
def warm_train(handles, hparams, train_batches, val_batches):
  with tf.Session(graph=tf.Graph()) as sess:
    restore_model(sess, hparams.checkpoint_prefix)
    train(sess, handles, hparams, train_batches, val_batches)


# Note: warmHparams is the same object as Hparams (make_namedtuple returns a class),
# so this also changes Hparams.num_epochs.
warmHparams = Hparams
warmHparams.num_epochs = 5
warm_train(Handles, warmHparams, train_batches, val_batches)
In [19]:
# See how we are doing.

def translate(input_ids_var_length, handles, hparams, source_lookup):
  graph = tf.Graph()
  with tf.Session(graph=graph) as sess:
    restore_model(sess, hparams.checkpoint_prefix)

    input_ids, input_lengths = pad_right(input_ids_var_length, source_lookup.vocab_to_id)
    feed = {graph.get_tensor_by_name(f'{handles.batch_size}:0'): len(input_ids_var_length),
            graph.get_tensor_by_name(f'{handles.input_ids}:0'): input_ids,
            graph.get_tensor_by_name(f'{handles.input_lengths}:0'): input_lengths,
            graph.get_tensor_by_name(f'{handles.max_decode_iterations}:0'): 2 * np.max(input_lengths)}

    infer_logits = graph.get_tensor_by_name(f'{handles.infer_logits}:0')
    logits_val = sess.run(infer_logits, feed)
    target_ids = np.argmax(logits_val, axis=2)

    return target_ids


def lookup(sequences, d):
  return [[d[k] for k in seq] for seq in sequences]

def make_unseen(hparams, seen, n):
  made = []
  max_loop = n * ((hparams.data_upper - hparams.data_lower) - len(seen))
  loop_counter = 0
  while len(made) < n:
    while True:
      loop_counter += 1
      if loop_counter > max_loop:
        raise Exception('Reached max loop')
      x, = np.random.randint(hparams.data_lower, hparams.data_upper, (1,))
      if x not in seen:
        made.append(x)
        break
  return made


def run_test(hparams, handles, test_ints, seen, source_lookup, target_lookup):
  numbers = make_unseen(hparams, seen, 20) + test_ints
  batch = [int_to_word(x) for x in numbers]
  predict_ids = translate(lookup(batch, source_lookup.vocab_to_id), handles, hparams, source_lookup)
  print('columns are: is correct, is similar data, truth, predicted, source sentence')
  for words, target_ids, x in zip(batch, predict_ids, numbers):
    target_ids_chopped = itertools.takewhile(lambda t_id: t_id != target_lookup.vocab_to_id['<EOS>'], target_ids)
    result = ''.join(lookup([target_ids_chopped], target_lookup.id_to_vocab)[0])
    marker = 'Y' if str(x) == result else 'N'
    is_similar_data = 'Y' if hparams.data_lower <= x < hparams.data_upper else 'N'
    print(f'{marker} {is_similar_data} {x:>11} {"<empty>" if result=="" else result:<11} {words}')

run_test(Hparams, Handles, test_ints, seen, source_lookup, target_lookup)
INFO:tensorflow:Restoring parameters from checkpoints/attention_word_to_number/model
INFO:tensorflow:Restoring parameters from checkpoints/attention_word_to_number/model
columns are: is correct, is similar data, truth, predicted, source sentence
Y Y      213027 213027      two hundred and thirteen thousand, twenty seven
Y Y      151919 151919      one hundred and fifty one thousand, nine hundred and nineteen
Y Y      259641 259641      two hundred and fifty nine thousand, six hundred and forty one
Y Y      279528 279528      two hundred and seventy nine thousand, five hundred and twenty eight
Y Y      109454 109454      one hundred and nine thousand, four hundred and fifty four
Y Y      146741 146741      one hundred and forty six thousand, seven hundred and forty one
Y Y       80584 80584       eighty thousand, five hundred and eighty four
Y Y      264908 264908      two hundred and sixty four thousand, nine hundred and eight
Y Y      117973 117973      one hundred and seventeen thousand, nine hundred and seventy three
Y Y      108302 108302      one hundred and eight thousand, three hundred and two
Y Y      131445 131445      one hundred and thirty one thousand, four hundred and forty five
Y Y      280951 280951      two hundred and eighty thousand, nine hundred and fifty one
Y Y       27503 27503       twenty seven thousand, five hundred and three
Y Y      188597 188597      one hundred and eighty eight thousand, five hundred and ninety seven
Y Y       55731 55731       fifty five thousand, seven hundred and thirty one
Y Y      168913 168913      one hundred and sixty eight thousand, nine hundred and thirteen
Y Y      105271 105271      one hundred and five thousand, two hundred and seventy one
Y Y      186516 186516      one hundred and eighty six thousand, five hundred and sixteen
Y Y      229783 229783      two hundred and twenty nine thousand, seven hundred and eighty three
Y Y      279002 279002      two hundred and seventy nine thousand, two
N Y           0 003         zero
N Y           1 9           one
N Y           7 70          seven
N Y          17 77          seventeen
Y Y          20 20          twenty
N Y          30 32          thirty
Y Y          45 45          forty five
Y Y          99 99          ninety nine
N Y         100 010         one hundred
N Y         101 011         one hundred and one
N Y         123 223         one hundred and twenty three
Y Y         999 999         nine hundred and ninety nine
N Y        1000 1001        one thousand
N Y        1001 1010        one thousand, one
Y Y        1234 1234        one thousand, two hundred and thirty four
Y Y        9999 9999        nine thousand, nine hundred and ninety nine
Y Y      123456 123456      one hundred and twenty three thousand, four hundred and fifty six
Y Y      234578 234578      two hundred and thirty four thousand, five hundred and seventy eight
N N      999999 199999      nine hundred and ninety nine thousand, nine hundred and ninety nine
N N     1000000 9992        one million
N N     1000001 1999        one million, one
N N     9123456 192156      nine million, one hundred and twenty three thousand, four hundred and fifty six
N N   999123456 199166      nine hundred and ninety nine million, one hundred and twenty three thousand, four hundred and fifty six
N N   123000789 123975      one hundred and twenty three million, seven hundred and eighty nine

As shown above, the model got all of the first 20 translations right. Overall, for similar data, that is, data in the range [Hparams.data_lower, Hparams.data_upper), the model is doing well: all of the longer sequences (> 1000) except one ("one thousand, one") are translated correctly. It is not doing as well on shorter sequences or on dissimilar data. Since an excellent model is not the concern of this notebook, we don't bother to tweak it further.

Visualization of the Alignment Vectors¶

Because of the way numbers are pronounced in English, as the decoder advances from left to right, so should its attention over the source sequence; thus, if we organize the alignment vectors into a target-time-major matrix, we should see higher values on the main diagonal. The image summary confirms this:

Image Summary of Alignment Vector
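
If you want to inspect the alignments outside of TensorBoard, here is a rough sketch that plots one alignment matrix with matplotlib (which is not used elsewhere in this notebook), assuming you have already fetched the value of the train_alignments tensor from a session:

import matplotlib.pyplot as plt

def plot_alignments(alignments_val, sample_index=0):
  # alignments_val: (batch_size, target_max_time, source_max_time),
  # e.g. the fetched value of the train_alignments tensor.
  fig, ax = plt.subplots()
  ax.imshow(alignments_val[sample_index], cmap='gray', aspect='auto')
  ax.set_xlabel('source time step')
  ax.set_ylabel('target time step')
  plt.show()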