Regularization loss in the ‘slim’ library of TensorFlow

Robin Dong 2018-07-19 15:16

My Python code, which uses the slim library to train a classification model in TensorFlow:

with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(weight_decay=0.001)):
    logits, _ = mobilenet_v2.mobilenet(images, NUM_CLASSES)

cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)

global_step = tf.contrib.framework.get_or_create_global_step()
train_op = tf.contrib.slim.learning.create_train_op(cross_entropy, opt, global_step=global_step)

It runs fine. However, no matter what value I set for ‘weight_decay’, the training accuracy easily climbs above 90%. It seems that ‘weight_decay’ simply has no effect.
To find out why, I read the TensorFlow source for ‘tf.losses.sparse_softmax_cross_entropy()’:

# tensorflow/python/ops/losses/
def sparse_softmax_cross_entropy(
    labels, logits, weights=1.0, scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
  # ... (docstring and argument checks omitted)
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      (logits, labels, weights)) as scope:
    # As documented above in Args, labels contain class IDs and logits contains
    # 1 probability per class ID, so we expect rank(logits) - rank(labels) == 1;
    # therefore, expected_rank_diff=1.
    labels, logits, weights = _remove_squeezable_dimensions(
        labels, logits, weights, expected_rank_diff=1)
    losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                         logits=logits,
                                                         name="xentropy")
    return compute_weighted_loss(
        losses, weights, scope, loss_collection, reduction=reduction)

‘losses.sparse_softmax_cross_entropy()’ simply calls ‘nn.sparse_softmax_cross_entropy_with_logits()’ and hands the result to ‘compute_weighted_loss()’. So let’s look at the implementation of ‘compute_weighted_loss()’:

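# tensorflow/python/ops/losses/
def compute_weighted_loss(
    losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
      # ... (weighting and reduction of the per-example losses omitted)
      loss = math_ops.cast(loss, input_dtype)
      util.add_loss(loss, loss_collection)
      return loss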

What is the secret in ‘util.add_loss()’?

# tensorflow/python/ops/losses/
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)

The loss from ‘losses.sparse_softmax_cross_entropy()’ is added to the ‘GraphKeys.LOSSES’ collection. Then where does the penalty on the weights go? Is it added to the same collection? Let’s check. Layers built with ‘tf.layers’ or ‘tf.contrib.slim’ inherit from ‘class Layer’, which calls ‘add_loss()’ when the layer creates a regularized variable through ‘add_variable()’. Here is ‘add_loss()’ of the base class ‘Layer’:

class Layer(checkpointable.CheckpointableBase):
    def add_loss(self, losses, inputs=None):
        _add_elements_to_collection(losses, ops.GraphKeys.REGULARIZATION_LOSSES)

That’s odd. The loss derived from the weights is not added to ‘GraphKeys.LOSSES’ but to ‘GraphKeys.REGULARIZATION_LOSSES’. Then how do we get all the losses at training time? Grepping the TensorFlow source for ‘REGULARIZATION_LOSSES’ turns up ‘get_total_loss()’:

# tensorflow/python/ops/losses/
def get_total_loss(add_regularization_losses=True, name="total_loss"):
  losses = get_losses()
  if add_regularization_losses:
    losses += get_regularization_losses()
  return math_ops.add_n(losses, name=name)
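
The bookkeeping traced so far can be mimicked in a few lines of plain Python. This is only a toy sketch, not TensorFlow code: the `collections` dictionary, the helper `layer_add_loss()`, and all the numbers are made up for illustration. It shows that a value taken from the data-loss collection alone never sees the regularization terms.

```python
# Toy stand-in for TensorFlow's graph collections; NOT the real implementation.
from collections import defaultdict

LOSSES = "losses"
REGULARIZATION_LOSSES = "regularization_losses"

collections = defaultdict(list)


def add_loss(loss, loss_collection=LOSSES):
    """Mimics util.add_loss(): data losses land in the LOSSES collection."""
    if loss_collection:
        collections[loss_collection].append(loss)


def layer_add_loss(loss):
    """Hypothetical helper mimicking Layer.add_loss(): penalties go elsewhere."""
    collections[REGULARIZATION_LOSSES].append(loss)


def get_total_loss(add_regularization_losses=True):
    """Mimics get_total_loss(): sums both collections unless told otherwise."""
    losses = list(collections[LOSSES])
    if add_regularization_losses:
        losses += collections[REGULARIZATION_LOSSES]
    return sum(losses)


# Hypothetical numbers, chosen only for illustration.
add_loss(0.75)          # cross-entropy, as tf.losses.* would register it
layer_add_loss(0.125)   # weight penalty from one regularized layer
layer_add_loss(0.25)    # weight penalty from another regularized layer

data_only = get_total_loss(add_regularization_losses=False)
total = get_total_loss()
print(data_only, total)
```

Here `data_only` corresponds to what my original training code optimized, while `total` also contains the two penalties.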

That is the secret of losses in ‘tf.layers’ and ‘tf.contrib.slim’: we should use ‘get_total_loss()’ to fetch the model loss and the regularization losses together!
After changing my code:

cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
cross_entropy = tf.reduce_mean(cross_entropy)

global_step = tf.contrib.framework.get_or_create_global_step()

# Sums GraphKeys.LOSSES and GraphKeys.REGULARIZATION_LOSSES into one tensor
loss = tf.contrib.slim.losses.get_total_loss()
train_op = tf.contrib.slim.learning.create_train_op(loss, opt, global_step=global_step)

Now ‘weight_decay’ takes effect: the training accuracy no longer climbs to a high value so easily.
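
Why the penalty makes the training data harder to fit can be seen on a toy one-parameter problem. This is a minimal sketch in plain Python rather than TensorFlow. The target value, the exaggerated decay of 0.5, and the learning rate are hypothetical. The penalty form `WEIGHT_DECAY * 0.5 * w**2` assumes an L2 regularizer that scales half the sum of squared weights.

```python
# Toy one-parameter problem; all numbers are hypothetical.
TARGET = 3.0        # weight value that minimizes the data loss alone
WEIGHT_DECAY = 0.5  # exaggerated on purpose so the effect is visible
LEARNING_RATE = 0.1


def data_grad(w):
    """Gradient of the data loss 0.5 * (w - TARGET)**2."""
    return w - TARGET


def penalty_grad(w):
    """Gradient of the assumed L2 penalty WEIGHT_DECAY * 0.5 * w**2."""
    return WEIGHT_DECAY * w


def train(use_total_loss, steps=500):
    """Plain gradient descent on either the data loss or the total loss."""
    w = 0.0
    for _ in range(steps):
        grad = data_grad(w)
        if use_total_loss:
            grad += penalty_grad(w)
        w -= LEARNING_RATE * grad
    return w


w_data_only = train(use_total_loss=False)  # like passing cross_entropy alone
w_total = train(use_total_loss=True)       # like passing get_total_loss()
print(w_data_only, w_total)
```

Without the penalty the weight settles at the data optimum (about 3.0). With it, the weight is pulled back toward zero and settles at about TARGET / (1 + WEIGHT_DECAY) = 2.0, so the data loss can no longer be driven all the way down.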
