I recently came across a very clean way to do this, from @AloneTogether in this answer:
import tensorflow as tf
data_tensor = tf.constant([3,5,6,2,6,1,3,9,5])
mask_tensor = tf.constant([0,1,1,1,0,0,1,1,0])
# Index where the mask changes.
change_idx = tf.concat([tf.where(mask_tensor[:-1] != mask_tensor[1:])[:, 0], [tf.shape(mask_tensor)[0]-1]], axis=0)
# Ranges of indices to gather.
ragged_idx = tf.ragged.range(tf.concat([[0], change_idx[:-1] + 1], axis=0), change_idx + 1)
# Gather ranges into ragged tensor.
output_tensor = tf.gather(data_tensor, ragged_idx)
print(output_tensor)
# <tf.RaggedTensor [[3], [5, 6, 2], [6, 1], [3, 9], [5]]>
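If it helps to see what the three steps are computing, here is a plain-Python sketch of the same logic with no TensorFlow involved. It is only an illustration, not part of the original answer: the function name `split_by_mask_runs` is my own, and it assumes 1-D inputs of equal length.

```python
def split_by_mask_runs(data, mask):
    """Split data into sublists, one per run of equal consecutive mask values.

    Hypothetical helper for illustration; mirrors the TensorFlow steps above.
    """
    if len(data) != len(mask):
        raise ValueError("data and mask must have the same length")
    if not data:
        return []

    # Last index of each run: every position where the mask differs from
    # its successor, plus the final index (same role as change_idx above).
    change_idx = [i for i in range(len(mask) - 1) if mask[i] != mask[i + 1]]
    change_idx.append(len(mask) - 1)

    # First index of each run: 0, then one past the end of the previous run
    # (same role as the starts passed to tf.ragged.range).
    starts = [0] + [i + 1 for i in change_idx[:-1]]

    # Gather each inclusive [start, end] range (same role as tf.gather).
    return [data[s:e + 1] for s, e in zip(starts, change_idx)]


data = [3, 5, 6, 2, 6, 1, 3, 9, 5]
mask = [0, 1, 1, 1, 0, 0, 1, 1, 0]
result = split_by_mask_runs(data, mask)
print(result)  # [[3], [5, 6, 2], [6, 1], [3, 9], [5]]
```

The groups match the ragged tensor above. Note that a run is defined by the mask value staying the same, so runs of 0s and runs of 1s both become their own group.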