NumPy reshape()方法详解

TensorFlow中利用张量来表示数据,张量是有阶数的,一阶张量就是一个一维向量,二阶张量是矩阵。NumPy中的shape就是张量的形状,用元组(tuple)表示,元组中元素表示相应维的大小。NumPy中reshape方法用来重组/改变张量的形状,是比较常用的方法。

reshape函数定义

reshape函数定义如下:

numpy.reshape(a, newshape, order='C')

前两个参数很好理解,分别指定了需要改变的张量及新的形状。第三个参数指定改变形状过程中的检索及放置元素的顺序。NumPy的文档中对第三个参数解释如下:

order : {‘C’, ‘F’, ‘A’}, optional

Read the elements of a using this index order, and place the elements into the reshaped array using this index order. ‘C’ means to read / write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to read / write the elements using Fortran-like index order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of indexing. ‘A’ means to read / write the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise.

其中,'A'表示自动选择顺序,所以顺序其实只有两种:'C'和'F'。

  • C顺序:指类似c语言的多维数组存储顺序,先遍历数组高维的索引;
  • F顺序:指类似Fortran语言的多维数组存储顺序,先遍历数组低维的索引;

下面我们通过实例来理解这两种顺序。在NumPy中,函数flatten也用到了C顺序和F顺序。

flatten函数

flatten函数用来将多维数组展开为一维数组,多维数组中的元素在一维数组中如何排列正是由C顺序/F顺序决定的。 C顺序中,先变化高维的索引,对于一个2X3的数组A来说,它展开为一维的顺序是:

A[0][0], A[0][1], A[0][2], A[1][0], A[1][1], A[1][2]

F顺序中,先变化低维的索引,对于一个2X3的数组A来说,它展开为一维的顺序是:

A[0][0], A[1][0], A[0][1], A[1][1], A[0][2], A[1][2]

简单来说,C顺序展开就是把数组按行展开,一行接一行;而F顺序展开就是按列展开,一列接一列。

>>> import numpy as np
>>> m = np.array([[1,2,3],[4,5,6]])
>>> m.flatten('C')
array([1, 2, 3, 4, 5, 6])
>>> m.flatten('F')
array([1, 4, 2, 5, 3, 6])

reshape函数如何工作

理解了C顺序和F顺序后,我们来看reshape是依赖于顺序(order)参数的。reshape的工作逻辑上可以分为两部分:

  • 将原数组按指定顺序展开为一维数组;
  • 将一维数组按排列顺序依次填入到新数组,多维数组的填入顺序依赖于指定的顺序;

实际上第一部分就是一个flatten的过程,第二部分如果是C顺序,则新数组按行填入;如果是F顺序,则新数组按列填入。

下面的例子中将2X3的矩阵变化为3X2的矩阵:

>>> import numpy as np
>>> >>> m = np.array([[1,2,3],[4,5,6]])
>>> m.shape
(2, 3)
>>> m1 = np.reshape(m, (3,2), 'C')
>>> m1
array([[1, 2],
       [3, 4],
       [5, 6]])
>>> m1.shape
(3, 2)
>>> m2 = np.reshape(m, (3,2), 'F')
>>> m2
array([[1, 5],
       [4, 3],
       [2, 6]])
>>> m2.shape
(3, 2)
>>>

reshape函数不改变存储

这里需要注意的是,reshape函数生成的新数组和原始数组是共用同一块内存的,也就是新旧数组对数据的修改会相互影响。

Python 3.7.1 (default, Dec 14 2018, 19:28:38)
>>> import numpy as np
>>> m = np.array([[1,2,3],[4,5,6]])
>>> m1 = np.reshape(m, (3,2))
>>> m1
array([[1, 2],
       [3, 4],
       [5, 6]])
>>> m[0][0] = 99
>>> m1
array([[99,  2],
       [ 3,  4],
       [ 5,  6]])
>>> m1[2][1]=0
>>> m
array([[99,  2,  3],
       [ 4,  5,  0]])
>>>

小结

reshape函数中顺序的缺省参数是'C',即按照类c语言的方式处理,大多数情况下我们不太需要关心顺序参数,但是也需要明白它的工作原理。 在NumPy或者其它机器学习、科学计算库中,经常会遇到诸如reshape(m,(-1,28,28))之类的表达,这里其中一个新维度是-1,实际上是告知NumPy库,该维度数量由NumPy根据原数组以及已指定的维度计算。