Understanding tf.extract_image_patches for extracting patches from an image
I found the following method tf.extract_image_patches in tensorflow API, but I am not clear about its functionality.
Say the batch_size = 1, an image is of size 225x225x3, and we want to extract patches of size 32x32.
How exactly does this function behave? Specifically, the documentation says the output tensor has dimensions [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth], but it does not say what out_rows and out_cols are.
Ideally, given an input image tensor of size 1x225x225x3 (where 1 is the batch size), I want to get Kx32x32x3 as output, where K is the total number of patches and 32x32x3 is the dimension of each patch. Is there something in TensorFlow that already does this?
Here is how the method works:
- ksizes ([1, ksize_rows, ksize_cols, 1]) decides the dimensions of each patch, in other words, how many pixels each patch contains.
- strides ([1, stride_rows, stride_cols, 1]) is the distance between the start of one patch and the start of the next consecutive patch within the original image.
- rates ([1, rate_rows, rate_cols, 1]) controls dilation: for each consecutive pixel that ends up in a patch, the patch jumps by rates pixels in the original image, so a patch of size k spans k + (k - 1) * (rate - 1) pixels along that axis. (The example below helps illustrate this.)
- padding is either "VALID", which means every patch must be fully contained in the image, or "SAME", which means patches are allowed to extend past the image border (the missing pixels are filled in with zeroes).
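To make the four parameters concrete, here is a minimal pure-NumPy sketch of the semantics described in the list. It is not TensorFlow's implementation: the function name extract_patches_np is hypothetical, and the SAME rule (extra padding pixel on the bottom/right) is an assumption about TensorFlow's convention that happens to agree with the outputs shown further down.

```python
import numpy as np


def extract_patches_np(images, ksizes, strides, rates, padding):
    """Hypothetical NumPy re-implementation of the patch-extraction semantics.

    images: array-like of shape [batch, rows, cols, depth].
    ksizes, strides, rates: [1, rows, cols, 1], as passed to tf.extract_image_patches.
    padding: 'VALID' or 'SAME'.
    Returns shape [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth].
    """
    images = np.asarray(images)
    batch, in_rows, in_cols, depth = images.shape
    _, k_r, k_c, _ = ksizes
    _, s_r, s_c, _ = strides
    _, r_r, r_c, _ = rates

    # With dilation, a k-pixel patch spans k + (k - 1) * (rate - 1) image pixels.
    eff_r = k_r + (k_r - 1) * (r_r - 1)
    eff_c = k_c + (k_c - 1) * (r_c - 1)

    if padding == 'VALID':
        # Only windows that fit entirely inside the image.
        out_rows = max(0, (in_rows - eff_r) // s_r + 1)
        out_cols = max(0, (in_cols - eff_c) // s_c + 1)
        pad_r = pad_c = 0
    elif padding == 'SAME':
        # One window per stride step; pixels outside the image become zeros.
        out_rows = -(-in_rows // s_r)  # ceil(in_rows / s_r)
        out_cols = -(-in_cols // s_c)
        pad_r = max((out_rows - 1) * s_r + eff_r - in_rows, 0)
        pad_c = max((out_cols - 1) * s_c + eff_c - in_cols, 0)
    else:
        raise ValueError("padding must be 'VALID' or 'SAME'")

    # Zero-pad, putting any odd extra pixel on the bottom/right (assumed convention).
    padded = np.pad(
        images,
        ((0, 0), (pad_r // 2, pad_r - pad_r // 2), (pad_c // 2, pad_c - pad_c // 2), (0, 0)),
    )

    out = np.zeros((batch, out_rows, out_cols, k_r * k_c * depth), dtype=images.dtype)
    for i in range(out_rows):
        for j in range(out_cols):
            top, left = i * s_r, j * s_c
            # Every rate-th pixel inside the dilated window gives k_r x k_c samples.
            patch = padded[:, top:top + eff_r:r_r, left:left + eff_c:r_c, :]
            # Flatten each patch in (row, col, depth) order into the last axis.
            out[:, i, j, :] = patch.reshape(batch, -1)
    return out
```

For example, with the 10x10 image built in the code below, extract_patches_np(images, [1, 3, 3, 1], [1, 5, 5, 1], [1, 1, 1, 1], 'VALID') should produce the same [1, 2, 2, 9] array as the first TensorFlow call.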
Here is some sample code with output to help demonstrate how it works:
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.extract_image_patches)

n = 10
# images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]

# We generate four outputs as follows:
# 1. 3x3 patches with stride length 5
# 2. Same as above, but the rate is increased to 2
# 3. 4x4 patches with stride length 7; only one patch should be generated
# 4. Same as above, but with padding set to 'SAME'
# (In TensorFlow 2.x the equivalent op is tf.image.extract_patches, which takes `sizes` instead of `ksizes`.)
with tf.Session() as sess:  # .eval() runs in this default session
    print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='SAME').eval())
Output:
[[[[ 1 2 3 11 12 13 21 22 23]
[ 6 7 8 16 17 18 26 27 28]]
[[51 52 53 61 62 63 71 72 73]
[56 57 58 66 67 68 76 77 78]]]]
[[[[ 1 3 5 21 23 25 41 43 45]
[ 6 8 10 26 28 30 46 48 50]]
[[ 51 53 55 71 73 75 91 93 95]
[ 56 58 60 76 78 80 96 98 100]]]]
[[[[ 1 2 3 4 11 12 13 14 21 22 23 24 31 32 33 34]]]]
[[[[ 1 2 3 4 11 12 13 14 21 22 23 24 31 32 33 34]
[ 8 9 10 0 18 19 20 0 28 29 30 0 38 39 40 0]]
[[ 71 72 73 74 81 82 83 84 91 92 93 94 0 0 0 0]
[ 78 79 80 0 88 89 90 0 98 99 100 0 0 0 0 0]]]]
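Each vector on the last axis of these results is one patch flattened in (row, col, depth) order, so the Kx32x32x3 layout asked about in the question is a reshape away, with K = batch * out_rows * out_cols. Below is a minimal NumPy sketch using the first output above (typed in by hand), which holds K = 1 * 2 * 2 = 4 patches of size 3x3x1. In TensorFlow the analogous call would be tf.reshape(patches, [-1, ksize_rows, ksize_cols, depth]), e.g. [-1, 32, 32, 3] for the question's case.

```python
import numpy as np

# First printed result above, shape [batch=1, out_rows=2, out_cols=2, 3*3*1].
first = np.array([[[[1, 2, 3, 11, 12, 13, 21, 22, 23],
                    [6, 7, 8, 16, 17, 18, 26, 27, 28]],
                   [[51, 52, 53, 61, 62, 63, 71, 72, 73],
                    [56, 57, 58, 66, 67, 68, 76, 77, 78]]]])

ksize_rows, ksize_cols, depth = 3, 3, 1

# Collapse batch/out_rows/out_cols into K and unflatten each patch.
patches = first.reshape(-1, ksize_rows, ksize_cols, depth)

print(patches.shape)                 # (4, 3, 3, 1)
print(patches[0, :, :, 0].tolist())  # [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
```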
So, for example, our first result looks like the following:
* * * 4 5 * * * 9 10
* * * 14 15 * * * 19 20
* * * 24 25 * * * 29 30
31 32 33 34 35 36 37 38 39 40
41 42 43 44 45 46 47 48 49 50
* * * 54 55 * * * 59 60
* * * 64 65 * * * 69 70
* * * 74 75 * * * 79 80
81 82 83 84 85 86 87 88 89 90
91 92 93 94 95 96 97 98 99 100
As you can see, there are 2 rows and 2 columns of patches; those counts are out_rows and out_cols.
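In general, out_rows and out_cols follow from the image size, patch size, stride and rate. The helper below is a hypothetical sketch (out_size is not a TensorFlow function) of the usual VALID/SAME rules; it reproduces the patch counts of all four examples above. For the question's 225x225x3 image, assuming non-overlapping 32x32 patches (stride 32), VALID gives a 7x7 grid, i.e. K = 49 patches with the last row and column of pixels dropped, while SAME would give 8x8 = 64 patches with zero padding.

```python
import math


def out_size(in_size, ksize, stride, rate=1, padding='VALID'):
    """Hypothetical helper: number of patch positions along one spatial axis."""
    eff = ksize + (ksize - 1) * (rate - 1)  # extent of a dilated patch
    if padding == 'VALID':
        # Windows must fit entirely inside the image.
        return max(0, math.ceil((in_size - eff + 1) / stride))
    if padding == 'SAME':
        # One window per stride step, regardless of patch size.
        return math.ceil(in_size / stride)
    raise ValueError("padding must be 'VALID' or 'SAME'")


# The four examples above (10x10 image):
print(out_size(10, 3, 5), out_size(10, 3, 5, rate=2))          # 2 2
print(out_size(10, 4, 7), out_size(10, 4, 7, padding='SAME'))  # 1 2

# Question's case, assuming 32x32 patches with stride 32:
rows = out_size(225, 32, 32)
print(rows * rows)  # 49
```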