我想要将mnist个数据集转换为(28, 28, 3)个维度,以便适合tf.keras.applications.MobileNetV2模型,但这个模型需要(x, y, 3)个维度.

https://www.tensorflow.org/tutorials/images/transfer_learning

第一个任务是将mnist (28,28,1)扩展到mnist (28,28,3),然后将(28,28,3)转换为(x,y,3).

下面是显示(28,28,1)张图片的代码:

import numpy as np
import matplotlib.pyplot as plt

x = np.arange(28*28*1).reshape(28,28,1)

plt.figure()
plt.imshow(x)
plt.title(x.shape)
plt.show()

enter image description here

以下代码试图显示(28,28,3),但它是从(28,28,1)转换成NOT的:

y = np.arange(28*28*3).reshape(28,28,3)

plt.figure()
plt.imshow(y)
plt.title(y.shape)
plt.show()

enter image description here

如何将上面的(28,28,1)图像转换成(28, 28, 3)并显示在matplotlib中?

Testing:

以下是比较原始图像(x)、NumPy RGB图像(y)、TensorFlow RGB(z)和填充零图像(pad_zero)的测试:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

x = np.arange(28*28*1).reshape(28,28,1)
x = x / x.max()

y = np.repeat(x, 3, axis=2)

z = tf.image.grayscale_to_rgb(tf.convert_to_tensor(x)).numpy()

def pad_with_zeros(a):
    a = a.copy()
    for ii, i in enumerate(a):
        for jj, j in enumerate(i):
            for kk, k in enumerate(j):
                if kk != 0:
                    a[ii, jj, kk] = 0
    return a

pad_zero = pad_with_zeros(y)

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
fig.subplots_adjust(wspace=0.1, hspace=0.1)

plt.subplot(1, 4, 1)
plt.imshow(x)
plt.title("x: {}".format(x.shape))

plt.subplot(1, 4, 2)
plt.imshow(y)
plt.title("np.repeat: {}".format(y.shape))

plt.subplot(1, 4, 3)
plt.imshow(z)
plt.title("tf.image: {}".format(z.shape))

plt.subplot(1, 4, 4)
plt.imshow(pad_zero)
plt.title("pad_zero: {}".format(pad_zero.shape))

plt.show()

enter image description here

为什么所有的图像 colored颜色 都是different

yz的 colored颜色 应该看起来像x,但它们是not.有什么不对劲吗?

推荐答案

你可以使用NumPy的repeat函数.

>>> x = np.ones((28, 28, 1))
>>> y = np.repeat(x, 3, axis=2)
>>> y.shape
(28, 28, 3)

Here是这种方法的文档.

Python相关问答推荐

@Property方法上的inspect.getmembers出现意外行为,引发异常

为什么我的Python代码在if-else声明中的行之前执行if-else声明中的行?

try 在树叶 map 上应用覆盖磁贴

Pytest两个具有无限循环和await命令的Deliverc函数

类型错误:输入类型不支持ufuncisnan-在执行Mann-Whitney U测试时[SOLVED]

图像 pyramid .难以创建所需的合成图像

加速Python循环

try 将一行连接到Tensorflow中的矩阵

使用密钥字典重新配置嵌套字典密钥名

python中的解释会在后台调用函数吗?

我的字符串搜索算法的平均时间复杂度和最坏时间复杂度是多少?

使用Python和文件进行模糊输出

基于形状而非距离的两个numpy数组相似性

Numpyro AR(1)均值切换模型抽样不一致性

通过追加列表以极向聚合

Polars map_使用多处理对UDF进行批处理

如何将一组组合框重置回无 Select tkinter?

用来自另一个数据框的列特定标量划分Polars数据框中的每一列,

牛郎星直方图中分类列的设置顺序

PyTorch变压器编码器中的填充掩码问题