admin管理员组

文章数量:1576378

Tensorflow2 Bug: triggered tf.function retracing

  • 背景
  • 尝试
  • References

背景

Bug在通过tf.function模式下运行时出现,应该与AutoGraph模式本身的使用机制有关, tensorflow中关于该Bug的帖子:

本人tensorflow2也用了快一年了,之前也写过好多相关代码/项目,都没有出现过该问题,这次突然出现这个问题, 好郁闷!

尝试

1.根据提示, tf.function具有experimental_relax_shapes = True选项,该选项可放宽参数形状,从而避免不必要的跟踪, 加上该参数后,即@tf.function(experimental_relax_shapes = True)发现并不管用,这就说明,并不是这个原因。实际我的程序中每个batch的尺寸都是固定的。
2.参考博客:https://blog.csdn/xygl2009/article/details/104443654, 通过tf.functioninput_signature参数来给方法的每个参数定义signature,如下:试了下,还是不管用。

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
    tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
	xxx  

3.最终我发现了问题所在:
@tf.function()
def train_step(batch_id: int) :
xxxx
函数train_step中由一个int类型的参数,然后利用该int参数来生成Tensor类型变量,并参与计算。这会导致一个问题:int参数是Python参数,在这些情况下,Python值的变化可能会触发不必要的回溯。举例来说,遍历batch循环训练,AutoGraph将动态展开。换句话说,每次call train_step时都会触发回溯,从而大大延迟程序执行,出现这个警告
解决方式: 通常,Python的参数被用来控制超参数和图的结构-例如,num_layers=10或training=True或nonlinearity=‘relu’。只是这样的话,并没有什么问题,我之前就这样用过, 而在这里,由于Python的参数参与了图的计算, 再用tf.function一静态图的方式运行时就会出现这个问题。然后我将int类型的参数在传入是转换为Tensor格式, 问题自动消失了。

References

1.https://www.tensorflow/api_docs/python/tf/function
2.https://www.tensorflow/tutorials/customization/performance#python_or_tensor_args
3.https://stackoverflow/questions/61647404/tensorflow-2-getting-warningtensorflow9-out-of-the-last-9-calls-to-function
4.

本文标签: triggeredBugretracingfunctionTF