下面的代码试图将张量转换为Tensorflow中的(x,y)维array.

使用此代码可以将"a"转换为"b",但"c"不能.

以下是测试代码:

def reshape_array(old_array, x, y):
    new_array = tf.reshape(old_array, [-1])
    
    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)
    
    diff = tf.math.subtract(reshape_size, current_size)
    if tf.greater_equal(diff, tf.constant([0])):
        new_array = tf.pad(new_array, [[0,0],[0, diff]], mode='CONSTANT', constant_values=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))
        
    return tf.cast(new_array, old_array.dtype)

a = tf.zeros(256*192*1)
print("a.shape: {}".format(a.shape))
b = reshape_array(a, 28, 28)
print("b.shape: {}".format(b.shape))

c = tf.constant([1, 2, 3, 4, 5, 6])
print("c.shape: {}".format(c.shape))
d = reshape_array(c, 28, 28)
print("d.shape: {}".format(d.shape))

以下是输出:

a.shape: (49152,)
b.shape: (28, 28)
c.shape: (6,)

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
/tmp/ipykernel_7071/4036910860.py in <cell line: 26>()
     24 c = tf.constant([1, 2, 3, 4, 5, 6])
     25 print("c.shape: {}".format(c.shape))
---> 26 d = reshape_array(c, 28, 28)
     27 print("d.shape: {}".format(d.shape))

/tmp/ipykernel_7071/4036910860.py in reshape_array(old_array, x, y)
      9     diff = tf.math.subtract(reshape_size, current_size)
     10     if tf.greater_equal(diff, tf.constant([0])):
---> 11         new_array = tf.pad(new_array, [[0,0],[0, diff]], mode='CONSTANT', constant_values=0)
     12         new_array = tf.reshape(new_array, (x, y))
     13     else:

/usr/local/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     52   try:
     53     ctx.ensure_initialized()
---> 54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:

InvalidArgumentError: The first dimension of paddings must be the rank of inputs[2,2] [6] [Op:Pad]

我的代码中有什么错误,如何修复?

推荐答案

在第二个示例中,您需要使用1D张量,因此请try :

import tensorflow as tf

def reshape_array(old_array, x, y):
    new_array = tf.reshape(old_array, [-1])
    
    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)
    
    diff = tf.math.subtract(reshape_size, current_size)
    if tf.greater_equal(diff, tf.constant([0])):
        print(diff)
        new_array = tf.pad(new_array, [[0, diff]], mode='CONSTANT', constant_values=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))
        
    return tf.cast(new_array, old_array.dtype)

a = tf.zeros(256*192*1)
print("a.shape: {}".format(a.shape))
b = reshape_array(a, 28, 28)
print("b.shape: {}".format(b.shape))

c = tf.constant([1, 2, 3, 4, 5, 6])
print("c.shape: {}".format(c.shape))
d = reshape_array(c, 28, 28)
print("d.shape: {}".format(d.shape))

在您的情况下,我通常更喜欢使用tf.concat作为填充:

def reshape_array(old_array, x, y):
    new_array = tf.reshape(old_array, [-1])
    
    current_size = tf.size(new_array)
    reshape_size = tf.math.multiply(x, y)
    
    diff = tf.math.subtract(reshape_size, current_size)
    if tf.greater_equal(diff, tf.constant([0])):
        new_array = tf.concat([new_array, tf.repeat([0], repeats=diff)], axis=0)
        new_array = tf.reshape(new_array, (x, y))
    else:
        new_array = tf.slice(new_array, begin=[0], size=[reshape_size])
        new_array = tf.reshape(new_array, (x, y))
        
    return tf.cast(new_array, old_array.dtype)

Python相关问答推荐

Python daskValue错误:无法识别的区块管理器dask -必须是以下之一:[]

从收件箱中的列中删除html格式

2D空间中的反旋算法

在Mac上安装ipython

无法在Docker内部运行Python的Matlab SDK模块,但本地没有问题

SQLAlchemy bindparam在mssql上失败(但在mysql上工作)

如何在Python中获取`Genericums`超级类型?

如何禁用FastAPI应用程序的Swagger UI autodoc中的application/json?

Python pint将1/华氏度转换为1/摄氏度°°

BeautifulSoup:超过24个字符(从a到z)的迭代失败:降低了首次深入了解数据集的复杂性:

如何获取包含`try`外部堆栈的`__traceback__`属性的异常

Django在一个不是ForeignKey的字段上加入'

如何在Python中从html页面中提取html链接?

Polars时间戳同步延迟计算

文本溢出了Kivy的视区

从列表中分离数据的最佳方式

如何在Quarto中的标题页之前创建序言页

为什么在不先将包作为模块导入的情况下相对导入不起作用

如何删除剪裁圆的对角线的外部部分

是否从Python调用SHGetKnownFolderPath?