假设我以这种方式定义了一个数据集:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
如何获得数据集中的元素数(因此,组成一个历元的单个元素数)?
我知道tf.data.Dataset
已经知道数据集的维度,因为repeat()
方法允许在指定的时间段内重复输入管道.所以这一定是一种获取信息的方式.
假设我以这种方式定义了一个数据集:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
如何获得数据集中的元素数(因此,组成一个历元的单个元素数)?
我知道tf.data.Dataset
已经知道数据集的维度,因为repeat()
方法允许在指定的时间段内重复输入管道.所以这一定是一种获取信息的方式.
tf.data.Dataset.list_files
创建一个名为MatchingFiles:0
的张量(如果适用,带有适当的前缀).
你可以判断
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
获取文件的数量.
当然,这只适用于简单的情况,尤其是如果每个图像只有一个样本(或已知数量的样本).
在更复杂的情况下,例如,当您不知道每个文件中的样本数时,只能在历元结束时观察样本数.
要做到这一点,你可以观看以你的Dataset
为单位的纪元数.repeat()
创建一个名为_count
的成员,该成员计算历代的数量.通过在迭代过程中观察它,您可以发现它何时发生变化,并从中计算数据集的大小.
这个计数器可能被埋在连续调用成员函数时创建的Dataset
的层次 struct 中,所以我们必须像这样把它挖出来.
d = my_dataset
# RepeatDataset seems not to be exposed -- this is a possible workaround
RepeatDataset = type(tf.data.Dataset().repeat())
try:
while not isinstance(d, RepeatDataset):
d = d._input_dataset
except AttributeError:
warnings.warn('no epoch counter found')
epoch_counter = None
else:
epoch_counter = d._count
请注意,使用这种技术,数据集大小的计算是不精确的,因为增加epoch_counter
的批次通常会混合来自两个连续时期的样本.所以这个计算精确到你的批次长度.