可以使用model = tf.function(model, jit_compile=True)
启用XLA.通过这种方式,有些型号的速度更快,有些则更慢.到现在为止还好.
但为什么在某些情况下,model = tf.function(model, jit_compile=None)
可以显著提高速度(没有TPU)?
jit_compile
docs个州:
如果为
None
(默认),则在TPU上运行时使用XLA编译函数 并在上运行时通过常规函数执行路径 其他设备.
我在两台非TPU(甚至非GPU)机器(安装了最新的TensorFlow(2.13.0
))上运行我的测试.
import timeit
import numpy as np
import tensorflow as tf
model_plain = tf.keras.applications.efficientnet_v2.EfficientNetV2S()
model_jit_compile_true = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=True)
model_jit_compile_false = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=False)
model_jit_compile_none = tf.function(tf.keras.applications.efficientnet_v2.EfficientNetV2S(), jit_compile=None)
def run(model):
model(np.random.random(size=(1, 384, 384, 3)))
# warmup
run(model_plain)
run(model_jit_compile_true)
run(model_jit_compile_false)
run(model_jit_compile_none)
runs = 10
duration_plain = timeit.timeit(lambda: run(model_plain), number=runs) / runs
duration_jit_compile_true = timeit.timeit(lambda: run(model_jit_compile_true), number=runs) / runs
duration_jit_compile_false = timeit.timeit(lambda: run(model_jit_compile_false), number=runs) / runs
duration_jit_compile_none = timeit.timeit(lambda: run(model_jit_compile_none), number=runs) / runs
print(f"{duration_plain=}")
print(f"{duration_jit_compile_true=}")
print(f"{duration_jit_compile_false=}")
print(f"{duration_jit_compile_none=}")
duration_plain=0.53095479644835
duration_jit_compile_true=1.5860380740836262
duration_jit_compile_false=0.09831228516995907
duration_jit_compile_none=0.09407951850444078