# seq2seq_addition.py
"""
The task is to train a simple seq2seq model that can do addition.

For example, given a *string* input "1+26",
the model should output a *string* "27".
"""

# %%
import logging
import time
import itertools
import os

import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib

# %%
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logger.info(tf.VERSION)

# %%
# Named tuple will be used to hold tf variable names and hyper parameters,
# so that no strings are passed around.
def make_namedtuple(name, field_names, field_values):
    import collections
    t = collections.namedtuple(name, field_names)
    for fname, fvalue in zip(field_names, field_values):
        setattr(t, fname, fvalue)
    return t

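# %%
# A quick illustration (added for exposition, not used elsewhere): make_namedtuple
# returns the namedtuple *class* with the values attached as class attributes,
# so fields are read directly off the returned object without instantiating it.
_Point = make_namedtuple('Point', ['x', 'y'], [1, 2])
assert _Point.x == 1 and _Point.y == 2
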
# %%
handle_fields = [
    'input_ids',
    'input_lengths',

    'decoder_input_ids',
    'decoder_input_lengths',
    'max_decode_iterations',

    'target_ids',
    'target_lengths',

    'train_loss',
    'optimize',
    'infer_logits',

    'batch_size',
    'train_loss_summary',
    'val_loss_summary']

# Various tf variable names.
Handles = make_namedtuple('Handles', handle_fields, handle_fields)

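# %%
# How the handles are used downstream (illustrative): each field doubles as the tf
# name of a node, so a placeholder created with name=Handles.input_ids can later be
# fetched from a restored graph as graph.get_tensor_by_name(f'{Handles.input_ids}:0').
logger.info(Handles.input_ids)     # 'input_ids'
logger.info(Handles.infer_logits)  # 'infer_logits'
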
# %%
# An encoder is just an RNN.
# We feed it the input sequence, harvest its final state,
# and feed that to the decoder.
def build_encoder(source_vocab_size, source_embedding_dim,
                  rnn_size, rnn_layers,
                  batch_size, handles):

    # Input to the encoder.
    # Each entry is a vocabulary id,
    # and every sequence (i.e. row) is padded to have the same length.
    # (batch_size, sequence_length)
    input_ids = tf.placeholder(tf.int32, (None, None), name=handles.input_ids)

    # Length of each sequence, without counting the padding.
    # (batch_size,)
    sequence_lengths = tf.placeholder(tf.int32, (None,), name=handles.input_lengths)

    # Embedded version of input_ids.
    # (batch_size, max_time, source_embedding_dim)
    # where max_time == tf.shape(input_ids)[1] == tf.reduce_max(sequence_lengths)
    inputs_embedded = contrib.layers.embed_sequence(
        ids=input_ids,
        vocab_size=source_vocab_size,
        embed_dim=source_embedding_dim)

    def build_cell():
        # The cell advances one time step in one layer.
        cell = tf.nn.rnn_cell.LSTMCell(
            num_units=rnn_size,
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
        return cell

    # Conceptually:
    #
    #   multi_rnn_cell(input, [layer1_state, layer2_state, ...])
    #       -> (output, [new_layer1_state, new_layer2_state, ...])
    #
    # i.e. advance one time step in each layer.
    multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(
        [build_cell() for _ in range(rnn_layers)])

    zero_state = multi_rnn_cell.zero_state(batch_size, tf.float32)

    # Advance sequence_lengths[i] time steps for sequence i, for all i.
    # It is a design choice that we are only interested in the final state.
    # (rnn_layers, 2=lstm_state_tuple_size, batch_size, rnn_size)
    _outputs, final_state = tf.nn.dynamic_rnn(
        cell=multi_rnn_cell,
        inputs=inputs_embedded,
        sequence_length=sequence_lengths,
        initial_state=zero_state)

    return final_state

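# %%
# A conceptual numpy sketch (not part of the graph) of what embed_sequence does:
# each id indexes a row of an embedding matrix, turning (batch, time) ids into
# (batch, time, embed_dim) vectors. The numbers here are arbitrary.
_emb = np.random.rand(15, 4)      # (vocab_size, embed_dim)
_ids = np.array([[3, 11, 2, 6]])  # (batch=1, time=4)
logger.info(_emb[_ids].shape)     # (1, 4, 4)
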
# %%
# The decoder is another RNN; it is a design choice that
# it has the same number of layers and size as the encoder RNN.
#
# The final state of the encoder is passed as the initial state of the decoder,
# and a start of sequence (SOS) symbol (its embedding vector, to be precise)
# is used as the input at time step 0.
#
# The output at each time step is projected (using a dense layer) to vocab_size logits.
# When making inferences, we sample from the output at time t,
# and use it as the input at time t+1.
#
# When training, we discard the RNN output,
# and feed it the truth corresponding to each time step.
def build_decoder(encoder_state,
                  target_vocab_size, target_embedding_dim,
                  rnn_size, rnn_layers,
                  batch_size, max_decode_iterations, vocab_to_id, handles):

    # Input sequence to the decoder; this is the truth, and is only used during training.
    # The first element of each sequence is always the SOS symbol.
    # (batch_size, sequence_length)
    input_ids = tf.placeholder(tf.int32, (None, None), name=handles.decoder_input_ids)

    # Length of input_ids, without counting padding.
    # (batch_size,)
    sequence_lengths = tf.placeholder(tf.int32, (None,), name=handles.decoder_input_lengths)

    input_embeddings = tf.Variable(tf.random_uniform((target_vocab_size, target_embedding_dim)))

    inputs_embedded = tf.nn.embedding_lookup(
        params=input_embeddings,
        ids=input_ids)

    def build_cell():
        cell = tf.nn.rnn_cell.LSTMCell(
            num_units=rnn_size,
            initializer=tf.random_uniform_initializer(-0.1, 0.1))
        return cell

    multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(
        [build_cell() for _ in range(rnn_layers)])

    # Transform the RNN output to logits, so that we can sample from it.
    projection_layer = tf.layers.Dense(
        units=target_vocab_size,
        activation=None,
        use_bias=False)

    def make_logits(helper, reuse):
        with tf.variable_scope('decoder', reuse=reuse):
            decoder = contrib.seq2seq.BasicDecoder(
                cell=multi_rnn_cell,
                helper=helper,
                initial_state=encoder_state,
                output_layer=projection_layer)

            final_outputs, _final_state, _final_sequence_length = contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                maximum_iterations=max_decode_iterations,
                impute_finished=True)

            # (batch_size, max_time, vocab_size)
            return final_outputs.rnn_output

    # The Helper is used to sample from the logits;
    # you can swap one Helper for another to get different sampling behaviour.
    #
    # At time t, a TrainingHelper just reads data from inputs[:, t]
    # and uses it as the input at time t.
    train_helper = contrib.seq2seq.TrainingHelper(
        inputs=inputs_embedded,
        sequence_length=sequence_lengths)

    # Greedy, as in GreedyEmbeddingHelper, usually means that a max() is taken.
    # So at time t, the output at t-1, which has the shape (vocab_size,)
    # (considering only one sample), is sampled from, and is used as the input at time t,
    # i.e. input at t = lookup_embedding(embedding, argmax(output at t-1))
    infer_helper = contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=input_embeddings,
        start_tokens=tf.tile(tf.constant([vocab_to_id['<SOS>']]), [batch_size]),
        end_token=vocab_to_id['<EOS>'])

    train_logits = make_logits(train_helper, reuse=False)

    # Use the same variables used in training.
    infer_logits = make_logits(infer_helper, reuse=True)
    tf.identity(infer_logits, name=handles.infer_logits)

    return train_logits

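# %%
# A pure-numpy sketch (illustrative only, not used by the model) of the feedback
# loop that GreedyEmbeddingHelper implements: the argmax of the logits at step t
# becomes the input id at step t+1, until <EOS> is produced or the cap is reached.
# `step_fn` here is a made-up stand-in for embedding an id and running the decoder
# cell one step.
def greedy_decode_sketch(step_fn, sos_id, eos_id, max_iterations):
    ids = []
    current_id = sos_id
    for _ in range(max_iterations):
        logits = step_fn(current_id)        # shape (vocab_size,)
        current_id = int(np.argmax(logits))
        if current_id == eos_id:
            break
        ids.append(current_id)
    return ids
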
# %%
# Vocabulary information

# +: we only handle addition
# -: for addition of negative numbers
# <PAD>: used to pad the sequences so that they have the same length
#        (and thus can form a proper np.ndarray or tf.Tensor)
# <SOS>, <EOS>: used to indicate the start and end of a sequence
vocabs = [str(i) for i in range(0, 10)] + ['+', '-', '<PAD>', '<SOS>', '<EOS>']

vocab_to_id = {c: i for i, c in enumerate(vocabs)}
id_to_vocab = {i: c for c, i in vocab_to_id.items()}

logger.info(vocabs)
logger.info(vocab_to_id)
logger.info(id_to_vocab)

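# %%
# Quick look at the mapping (purely for inspection): the input string "1+26" becomes
# a list of ids, and mapping back through id_to_vocab recovers the string.
_example_ids = [vocab_to_id[c] for c in '1+26']
logger.info(_example_ids)
logger.info(''.join(id_to_vocab[i] for i in _example_ids))
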
# %%
# Data preparation

def pad_right(lol, vocab_to_id):
    # Pad right with <PAD>, such that all rows have the same length.
    row_lengths = [len(row) for row in lol]
    max_length = np.max(row_lengths)
    arr = np.ndarray((len(lol), max_length))
    arr.fill(vocab_to_id['<PAD>'])
    for i, length, row in zip(range(len(lol)), row_lengths, lol):
        arr[i, :length] = row
    return arr, row_lengths


def push_sos(arr, vocab_to_id):
    # Prepend each row with <SOS>, and drop the last column,
    # such that the length is unchanged.
    soses = np.ndarray((arr.shape[0], 1))
    soses.fill(vocab_to_id['<SOS>'])
    with_sos = np.concatenate([soses, arr], axis=1)
    no_last = with_sos[:, :-1]
    return no_last


def make_batch(batch_size, lower, upper):
    # Make one batch.
    xys = np.random.randint(low=lower, high=upper, size=(batch_size, 2))
    zs = np.sum(xys, axis=1)

    source_ids = [[vocab_to_id[char] for char in f'{xy[0]}+{xy[1]}']
                  for xy in list(xys)]
    target_ids = [[vocab_to_id[char] for char in f'{z}'] + [vocab_to_id['<EOS>']]
                  for z in list(zs)]

    padded_source_ids, source_lengths = pad_right(source_ids, vocab_to_id)
    padded_target_ids, target_lengths = pad_right(target_ids, vocab_to_id)
    decoder_input_ids = push_sos(padded_target_ids, vocab_to_id)
    batch = (padded_source_ids, source_lengths,
             padded_target_ids, target_lengths,
             decoder_input_ids)
    return batch


# An iterator that produces feed dicts.
def get_feed(all_batches, graph, handles):
    for (padded_source_ids, source_lengths,
         padded_target_ids, target_lengths,
         decoder_input_ids
         ) in all_batches:
        yield {graph.get_tensor_by_name(f'{handles.batch_size}:0'): len(source_lengths),
               graph.get_tensor_by_name(f'{handles.input_ids}:0'): padded_source_ids,
               graph.get_tensor_by_name(f'{handles.input_lengths}:0'): source_lengths,
               graph.get_tensor_by_name(f'{handles.decoder_input_ids}:0'): decoder_input_ids,
               graph.get_tensor_by_name(f'{handles.decoder_input_lengths}:0'): target_lengths,
               graph.get_tensor_by_name(f'{handles.target_ids}:0'): padded_target_ids,
               graph.get_tensor_by_name(f'{handles.target_lengths}:0'): target_lengths,
               graph.get_tensor_by_name(f'{handles.max_decode_iterations}:0'): 2 * np.max(source_lengths)}

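# %%
# A small inspection of the batch format (illustrative only): pad_right pads the
# shorter row with <PAD> ids, and push_sos shifts each row right by one,
# prepending <SOS> and dropping the last column.
_demo_ids, _demo_lengths = pad_right(
    [[vocab_to_id[c] for c in '1+26'],
     [vocab_to_id[c] for c in '7+8']],
    vocab_to_id)
logger.info(_demo_ids)                         # shape (2, 4), second row ends with <PAD>
logger.info(_demo_lengths)                     # [4, 3]
logger.info(push_sos(_demo_ids, vocab_to_id))  # first column is <SOS>, last column dropped
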
# %% Hyper parameters and the like

Hparams = make_namedtuple('Hparams', *zip(*[
    # RNN state size and number of layers.
    # Since we are passing the final encoder state as the initial decoder state,
    # these are shared by both the encoder and the decoder.
    # Another thought is to use a dense layer to bridge the encoder state and the decoder state.
    ('rnn_size', 50),
    ('rnn_layers', 3),

    ('source_vocab_size', len(vocabs)),
    ('target_vocab_size', len(vocabs)),

    ('source_embedding_dim', 10),
    ('target_embedding_dim', 10),

    # ('batch_size', 3),
    # ('num_epochs', 2),
    # ('num_train_batches', 4),
    # ('num_val_batches', 4),

    # NB: you don't want batch_size * num_train_batches to be close to (data_upper - data_lower)**2,
    # otherwise the model may memorise the universe and we will have no test data.
    ('batch_size', 128),
    ('num_epochs', 480),
    ('num_train_batches', 100),
    ('num_val_batches', 10),

    ('lr', 0.001),
    ('data_lower', -100),
    ('data_upper', 100),

    ('checkpoint_prefix', 'checkpoints/seq2seq_addition/model'),
    ('tensorboard_dir', 'tensorboard')
]))
os.makedirs(os.path.dirname(Hparams.checkpoint_prefix), exist_ok=True)

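# %%
# Rough sanity check of the NB above (illustrative): one epoch sees
# batch_size * num_train_batches = 12800 examples, while the universe of (x, y)
# pairs has (data_upper - data_lower) ** 2 = 40000 members, so plenty remains unseen.
logger.info(Hparams.batch_size * Hparams.num_train_batches)
logger.info((Hparams.data_upper - Hparams.data_lower) ** 2)
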
# %%
train_batches = [
    make_batch(Hparams.batch_size, Hparams.data_lower, Hparams.data_upper)
    for _ in range(Hparams.num_train_batches)]

val_batches = [
    make_batch(Hparams.batch_size, Hparams.data_lower, Hparams.data_upper)
    for _ in range(Hparams.num_val_batches)]

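# %%
# Peek at one prepared batch (illustrative): the five pieces returned by make_batch.
_src, _src_lens, _tgt, _tgt_lens, _dec_in = train_batches[0]
logger.info(_src.shape)      # (batch_size, longest source length)
logger.info(_src_lens[:3])   # true (unpadded) lengths of the first few sources
logger.info(_tgt.shape)      # (batch_size, longest target length, incl. <EOS>)
logger.info(_dec_in.shape)   # same shape as the padded targets
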
# %%
# Create the graph.
def make_graph(handles, hparams):
    graph = tf.Graph()
    with graph.as_default():
        batch_size = tf.placeholder(tf.int32, (), name=handles.batch_size)

        encoder_final_state = build_encoder(
            source_vocab_size=hparams.source_vocab_size,
            source_embedding_dim=hparams.source_embedding_dim,
            rnn_size=hparams.rnn_size,
            rnn_layers=hparams.rnn_layers,
            batch_size=batch_size,
            handles=handles)

        max_decode_iterations = tf.placeholder(tf.int32, (), name=handles.max_decode_iterations)

        train_logits = build_decoder(
            encoder_state=encoder_final_state,
            target_vocab_size=hparams.target_vocab_size,
            target_embedding_dim=hparams.target_embedding_dim,
            rnn_size=hparams.rnn_size,
            rnn_layers=hparams.rnn_layers,
            batch_size=batch_size,
            max_decode_iterations=max_decode_iterations,
            vocab_to_id=vocab_to_id,
            handles=handles)

        # Labels, so they have EOS tokens in them.
        # Used only during training.
        # (batch_size, target_sequence_length)
        target_ids = tf.placeholder(tf.int32, (None, None), name=handles.target_ids)

        # Length of target_ids, without counting padding.
        # (batch_size,)
        target_lengths = tf.placeholder(tf.int32, (None,), name=handles.target_lengths)

        # Since our target_ids are effectively of variable length,
        # we mask out the padding positions.
        loss_mask = tf.sequence_mask(lengths=target_lengths, dtype=tf.float32)
        train_loss_ = contrib.seq2seq.sequence_loss(
            logits=train_logits,
            targets=target_ids,
            weights=loss_mask)
        train_loss = tf.identity(train_loss_, name=handles.train_loss)

        tf.summary.scalar(handles.train_loss_summary, train_loss)
        tf.summary.scalar(handles.val_loss_summary, train_loss)

        optimizer = tf.train.AdamOptimizer(hparams.lr)
        gradients = optimizer.compute_gradients(train_loss)
        clipped_gradients = [(tf.clip_by_value(grad, -5., 5.), var)
                             for grad, var in gradients if grad is not None]
        optimizer.apply_gradients(clipped_gradients, name=handles.optimize)

    return graph

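# %%
# A numpy sketch (not TF) of the loss mask built by tf.sequence_mask above:
# positions before each target's true length get weight 1.0, padding gets 0.0.
_lens = np.array([2, 4])
_mask = (np.arange(_lens.max())[None, :] < _lens[:, None]).astype(np.float32)
logger.info(_mask)  # [[1. 1. 0. 0.]
                    #  [1. 1. 1. 1.]]
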
# %%
def restore_model(sess, checkpoint_prefix):
    loader = tf.train.import_meta_graph(checkpoint_prefix + '.meta')
    loader.restore(sess, checkpoint_prefix)


def save_model(sess, checkpoint_prefix):
    saver = tf.train.Saver()
    saver.save(sess, checkpoint_prefix)
    logger.info(f'model saved to {checkpoint_prefix}')

# %%
def train(sess: tf.Session, handles: Handles, hparams: Hparams) -> None:
    with tf.summary.FileWriter(
            logdir=hparams.tensorboard_dir,
            graph=sess.graph
    ) as summary_writer:

        train_loss = sess.graph.get_tensor_by_name(f'{handles.train_loss}:0')
        optimize = sess.graph.get_operation_by_name(handles.optimize)

        train_loss_summary = sess.graph.get_tensor_by_name(f'{handles.train_loss_summary}:0')
        val_loss_summary = sess.graph.get_tensor_by_name(f'{handles.val_loss_summary}:0')

        global_step = 0
        for i_epoch in range(1, hparams.num_epochs + 1):
            time_begin = time.monotonic()
            train_loss_vals = []
            for feed in get_feed(all_batches=train_batches, graph=sess.graph, handles=handles):
                global_step += 1
                train_loss_val, _, summary_val = sess.run([train_loss, optimize, train_loss_summary], feed)
                summary_writer.add_summary(summary_val, global_step=global_step)
                train_loss_vals.append(train_loss_val)

            val_loss_vals = []
            for feed in get_feed(all_batches=val_batches, graph=sess.graph, handles=handles):
                val_loss_val, summary_val = sess.run([train_loss, val_loss_summary], feed)
                summary_writer.add_summary(summary_val, global_step=global_step)
                val_loss_vals.append(val_loss_val)

            train_loss_val = np.mean(train_loss_vals[-len(val_loss_vals):])
            val_loss_val = np.mean(val_loss_vals)

            time_end = time.monotonic()
            logger.info(' '.join([
                f'epoch={i_epoch:0{len(str(hparams.num_epochs))}d}/{hparams.num_epochs}',
                f'train_loss={train_loss_val:.4f}',
                f'val_loss={val_loss_val:.4f}',
                f'duration={time_end-time_begin:.4f}s']))

        save_model(sess, hparams.checkpoint_prefix)

# %%
# Train the model from scratch.
def cold_train(handles: Handles, hparams: Hparams) -> None:
    with tf.Session(graph=make_graph(handles, hparams)) as sess:
        sess.run(tf.global_variables_initializer())
        save_model(sess, hparams.checkpoint_prefix)
        train(sess, handles, hparams)


cold_train(Handles, Hparams)

# ...
# INFO:__main__:epoch=358/360 train_loss=0.0905 val_loss=0.1683 duration=4.9432s
# INFO:__main__:epoch=359/360 train_loss=0.1180 val_loss=0.1358 duration=4.8489s
# INFO:__main__:epoch=360/360 train_loss=0.1868 val_loss=0.1850 duration=4.6427s
# INFO:__main__:model trained and saved to ckpts/ckpt

# %%
# If the loss is still too big but still decreasing,
# we can load the trained model and continue training.
def warm_train(handles: Handles, hparams: Hparams) -> None:
    with tf.Session(graph=tf.Graph()) as sess:
        restore_model(sess, hparams.checkpoint_prefix)
        train(sess, handles, hparams)


# NB: make_namedtuple returns a class, so warmHparams is the same object as Hparams;
# setting num_epochs here updates Hparams itself, which is fine for our purpose.
warmHparams = Hparams
warmHparams.num_epochs = 120
warm_train(Handles, warmHparams)

# ...
# INFO:__main__:epoch=118/120 train_loss=0.0716 val_loss=0.1247 duration=5.0092s
# INFO:__main__:epoch=119/120 train_loss=0.0484 val_loss=0.0832 duration=4.9530s
# INFO:__main__:epoch=120/120 train_loss=0.0448 val_loss=0.0738 duration=4.9770s
# INFO:__main__:model trained and saved to ckpts/ckpt

# %%
# See how we are doing.

def translate(input_ids_var_length, handles, hparams):
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        restore_model(sess, hparams.checkpoint_prefix)

        input_ids, input_lengths = pad_right(input_ids_var_length, vocab_to_id)
        feed = {graph.get_tensor_by_name(f'{handles.batch_size}:0'): len(input_ids_var_length),
                graph.get_tensor_by_name(f'{handles.input_ids}:0'): input_ids,
                graph.get_tensor_by_name(f'{handles.input_lengths}:0'): input_lengths,
                graph.get_tensor_by_name(f'{handles.max_decode_iterations}:0'): 2 * np.max(input_lengths)}

        infer_logits = graph.get_tensor_by_name(f'{handles.infer_logits}:0')
        logits_val = sess.run(infer_logits, feed)
        target_ids = np.argmax(logits_val, axis=2)

        return target_ids


def lookup(sequences, d):
    return [[d[k] for k in seq] for seq in sequences]


def ids_to_int(ids, id_to_vocab):
    num_str = ''.join([id_to_vocab[x] for x in ids])
    return int(num_str)


# All seen training and validation examples.
def all_seen(batches, vocab_to_id, id_to_vocab):
    seen = []
    for batch in batches:
        seqs = batch[0]
        for seq in seqs:
            filling_left = True
            left, right = [], []
            for id_ in seq:
                if filling_left:
                    if id_ == vocab_to_id['+']:
                        filling_left = False
                    else:
                        left.append(id_)
                else:
                    if id_ == vocab_to_id['<PAD>']:
                        break
                    else:
                        right.append(id_)
            seen.append((ids_to_int(left, id_to_vocab), ids_to_int(right, id_to_vocab)))
    return seen


def make_unseen(hparams, seen, n):
    made = 0
    res = []
    while made < n:
        while True:
            x, y = np.random.randint(hparams.data_lower, hparams.data_upper, (2,))
            if (x, y) not in seen:
                res.append((x, y))
                made += 1
                break
    return res


seen = all_seen(train_batches + val_batches, vocab_to_id, id_to_vocab)
xys = make_unseen(Hparams, seen, 10) + [
    (1, 1), (12, 12), (23, 32), (0, 1), (0, 0), (50, 50), (98, 98), (99, 99),
    (100, 99), (101, 102), (123, 54), (142, 173), (256, 254)]

batch = [f'{x}+{y}' for x, y in xys]
predict_ids = translate(lookup(batch, vocab_to_id), Handles, Hparams)
for expr, target_ids, (x, y) in zip(batch, predict_ids, xys):
    target_ids_chopped = itertools.takewhile(lambda x: x != vocab_to_id['<EOS>'], target_ids)
    result = ''.join(lookup([target_ids_chopped], id_to_vocab)[0])
    marker = '✓' if str(x + y) == result else '✗'
    print(f'{marker} {expr} = {"<empty>" if result == "" else result}')

# As can be seen from the following output,
# the model does OK on data similar to what it was trained on (14 / 18 correct),
# and gets all of the dissimilar data wrong.

# ...
# INFO:tensorflow:Restoring parameters from ckpts/ckpt
# ✓ 4+29 = 33
# ✓ 71+72 = 143
# ✓ -24+25 = 1
# ✗ 0+77 = 78
# ✓ 59+-27 = 32
# ✓ 70+33 = 103
# ✓ -73+61 = -12
# ✓ -79+-73 = -152
# ✓ 16+34 = 50
# ✓ -41+4 = -37
# ✓ 1+1 = 2
# ✓ 12+12 = 24
# ✓ 23+32 = 55
# ✓ 0+1 = 1
# ✗ 0+0 = 3
# ✓ 50+50 = 100
# ✗ 98+98 = 1949
# ✗ 99+99 = 1959
# ✗ 100+99 = 76
# ✗ 101+102 = -77
# ✗ 123+54 = 48
# ✗ 142+173 = -34
# ✗ 256+254 = 30