A bug when using hooks with MirroredStrategy in tf.estimator.Estimator

Robin Dong 2018-11-09 14:25

When I was using MirroredStrategy in my tf.estimator.Estimator:

distribution = tf.contrib.distribute.MirroredStrategy(
      ["/device:GPU:0", "/device:GPU:1"])
config = tf.estimator.RunConfig(train_distribute=distribution,
                                  eval_distribute=distribution)
estimator = tf.estimator.Estimator(
      model_fn=build_model_fn_optimizer(), config=config)
estimator.train(input_fn=input_fn, steps=10)

and added a logging hook for training:

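    # Inside model_fn. LoggingTensorHook needs every_n_iter or every_n_secs; 10 is an arbitrary choice.
    logging_hook = tf.train.LoggingTensorHook({'logits': logits}, every_n_iter=10)
    return tf.estimator.EstimatorSpec(
        mode, loss=loss_fn(), train_op=train_op, training_hooks=[logging_hook])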

TensorFlow reported the following error:

  File "/usr/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 356, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/usr/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1179, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/usr/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1309, in _train_model_distributed
    grouped_estimator_spec.training_hooks)
  File "/usr/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1305, in get_hooks_from_the_first_device
    for per_device_hook in per_device_hooks
  File "/usr/lib/python3.6/site-packages/tensorflow/python/estimator/estimator.py", line 1305, in <listcomp>
    for per_device_hook in per_device_hooks
AttributeError: 'Estimator' object has no attribute '_distribution'

I could not find any answers on Google, so I had to look into the code of ‘estimator.py’ in TensorFlow. Fortunately, the defect is obvious:

        scaffold = _combine_distributed_scaffold(
            grouped_estimator_spec.scaffold, self._train_distribution)

        # TODO(yuefengz): add a test for unwrapping per_device_hooks.
        def get_hooks_from_the_first_device(per_device_hooks):
          return [
              self._distribution.unwrap(per_device_hook)[0]
              for per_device_hook in per_device_hooks
          ]
            
        training_hooks = get_hooks_from_the_first_device(
            grouped_estimator_spec.training_hooks)

The Estimator class has no private attribute named ‘_distribution’; it only has ‘_train_distribution’ and ‘_eval_distribution’. So the fix is simply to change ‘self._distribution.unwrap(per_device_hook)[0]’ to ‘self._train_distribution.unwrap(per_device_hook)[0]’.
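To make the failure mode concrete, here is a small pure-Python sketch that does not import TensorFlow at all. FakeStrategy and MiniEstimator are hypothetical stand-ins I made up for illustration; they only mirror the attribute names discussed above, and a per-device value is modeled as a plain tuple with one entry per device.

```python
class FakeStrategy:
    """Hypothetical stand-in for a distribution strategy (not TensorFlow's class)."""

    def unwrap(self, per_device_value):
        # A per-device value is modeled as a tuple, one entry per device.
        return tuple(per_device_value)


class MiniEstimator:
    """Hypothetical stand-in that only mirrors the attribute names from the post."""

    def __init__(self, train_distribution):
        # Only `_train_distribution` is set; `_distribution` never exists.
        self._train_distribution = train_distribution

    def hooks_buggy(self, per_device_hooks):
        # Mirrors the defect: reads an attribute that was never assigned.
        return [self._distribution.unwrap(h)[0] for h in per_device_hooks]

    def hooks_fixed(self, per_device_hooks):
        # Mirrors the fix: reads the attribute that actually exists.
        return [self._train_distribution.unwrap(h)[0] for h in per_device_hooks]


estimator = MiniEstimator(FakeStrategy())
per_device_hooks = [("hook_a@gpu0", "hook_a@gpu1"),
                    ("hook_b@gpu0", "hook_b@gpu1")]

try:
    estimator.hooks_buggy(per_device_hooks)
    buggy_error = None
except AttributeError as err:
    buggy_error = str(err)

first_device_hooks = estimator.hooks_fixed(per_device_hooks)
no_hooks_result = estimator.hooks_buggy([])  # comprehension body never runs

print(buggy_error)
print(first_device_hooks)
```

In this sketch the buggy version returns an empty list without complaint when no hooks are passed, because the comprehension body is never evaluated. That is consistent with the error only showing up once training_hooks is non-empty.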

I have submitted a pull request to TensorFlow to fix this bug in the 1.11 branch.
