admin管理员组文章数量:1576378
Tensorflow2 Bug: triggered tf.function retracing
- 背景
- 尝试
- References
背景
该Bug在通过tf.function模式下运行时出现,应该与AutoGraph模式本身的使用机制有关, tensorflow中关于该Bug的帖子:
本人tensorflow2也用了快一年了,之前也写过好多相关代码/项目,都没有出现过该问题,这次突然出现这个问题, 好郁闷!
尝试
1.根据提示, tf.function
具有experimental_relax_shapes = True
选项,该选项可放宽参数形状,从而避免不必要的跟踪, 加上该参数后,即@tf.function(experimental_relax_shapes = True)
发现并不管用,这就说明,并不是这个原因。实际我的程序中每个batch的尺寸都是固定的。
2.参考博客:https://blog.csdn/xygl2009/article/details/104443654, 通过tf.function的input_signature参数来给方法的每个参数定义signature,如下:试了下,还是不管用。
train_step_signature = [
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
tf.TensorSpec(shape=(None, None), dtype=tf.int64),
]
@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
xxx
3.最终我发现了问题所在:
@tf.function()
def train_step(batch_id: int) :
xxxx
函数train_step中由一个int类型的参数,然后利用该int参数来生成Tensor类型变量,并参与计算。这会导致一个问题:int参数是Python参数,在这些情况下,Python值的变化可能会触发不必要的回溯。举例来说,遍历batch循环训练,AutoGraph将动态展开。换句话说,每次call train_step时都会触发回溯,从而大大延迟程序执行,出现这个警告
。
解决方式: 通常,Python的参数被用来控制超参数和图的结构-例如,num_layers=10或training=True或nonlinearity=‘relu’。只是这样的话,并没有什么问题,我之前就这样用过, 而在这里,由于Python的参数参与了图的计算, 再用tf.function一静态图的方式运行时就会出现这个问题。然后我将int类型的参数在传入是转换为Tensor格式
, 问题自动消失了。
References
1.https://www.tensorflow/api_docs/python/tf/function
2.https://www.tensorflow/tutorials/customization/performance#python_or_tensor_args
3.https://stackoverflow/questions/61647404/tensorflow-2-getting-warningtensorflow9-out-of-the-last-9-calls-to-function
4.
本文标签: triggeredBugretracingfunctionTF
版权声明:本文标题:Tensorflow2 Bug:triggered tf.function retracing (已解决) 内容由热心网友自发贡献,该文观点仅代表作者本人, 转载请联系作者并注明出处:https://m.elefans.com/dianzi/1727798320a1130527.html, 本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。
发表评论