我正在学习Udacity的Intro to TensorFlow for Deep Learning门课程.其中一节课是通过一个Google Colab笔记本来学习如何使用Fashion MNIST数据集(ColabGitHub个链接).

显示数据集中的图像时,代码总是使用NumPy方法take()提取第一个图像.我想知道如何访问数据集的不同部分,但我对Python的知识还不够.

第一个例子是:

# Take a single image, and remove the color dimension by reshaping
for image, label in test_dataset.take(1):
  break;

image = image.numpy().reshape((28,28))

我该怎么做才能只拿第二个?

通过阅读this question的一些答案,我发现了skip()方法,但它似乎需要所有元素,直到结束,但我只想提取几个项目,在本例中仅提取一个.

第二个例子是:

plt.figure(figsize=(10,10))
for i, (image, label) in enumerate(train_dataset.take(25)):
    image = image.numpy().reshape((28,28))
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image, cmap=plt.cm.binary)
    plt.xlabel(class_names[label])
plt.show()

这一条包含了前25条:我如何开始阅读第100条的图像?

在第三个也是最后一个例子中,他们做了一件非常相似的事情:

for test_images, test_labels in test_dataset.take(1):
  test_images = test_images.numpy()
  test_labels = test_labels.numpy()
  predictions = model.predict(test_images)

如何从第二批或第三批数据开始?

推荐答案

要在第100次之后获取项目,您可以执行以下操作:

  1. 设置batch_size = 100
  2. 使用此行:dataset = dataset.skip(1).take(1).这一行跳过前dataset = dataset.skip(1).take(1)个元素,从(because batch_size == 100)开始到200 (because batch_size == 100).

示例代码:

import tensorflow_datasets as tfds
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np


dataset = tfds.load('fashion_mnist', as_supervised=True, split = 'train').batch(100)
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
for iter, (image, label) in enumerate(dataset.skip(1).take(2)):
    print(f"{1*(iter+1)}01 -> {1*(iter+1)}10 images")
    fig, axes = plt.subplots(2,5,figsize=(15,6))
    for idx, axe in enumerate(axes.flatten()):
        axe.axis('off')
        axe.imshow(image[idx][...,0])
        axe.set_title(class_names[label[idx]])
    plt.show()
    print()

输出:

enter image description here

enter image description here

Python相关问答推荐

Pydantic 2.7.0模型接受字符串日期时间或无

Python daskValue错误:无法识别的区块管理器dask -必须是以下之一:[]

如何使用symy打印方程?

PywinAuto在Windows 11上引发了Memory错误,但在Windows 10上未引发

Pandas 都是(),但有一个门槛

按顺序合并2个词典列表

如何使用数组的最小条目拆分数组

Pandas DataFrame中行之间的差异

让函数调用方程

在两极中过滤

找到相对于列表索引的当前最大值列表""

解决Geopandas和Altair中的正图和投影问题

有没有办法让Re.Sub报告它所做的每一次替换?

如何设置nan值为numpy数组多条件

如何提高Pandas DataFrame中随机列 Select 和分配的效率?

用0填充没有覆盖范围的垃圾箱

如何在Polars中创建条件增量列?

来自任务调度程序的作为系统的Python文件

GEKKO中若干参数的线性插值动态优化

如何将参数名作为参数传入到函数中?