我自己写的类似tensorflow的conv2d函数对padding=SAME的场景无法通过测试
问题描述:
#现象
自己实现的conv2d在padding=same的参数下无法通过与tensorflow内部conv2d函数的对比测试,但是对padding=valid的可以
#我的代码
def conv2d(input, filter, stride, padding):
# batch x height x width x channels
in_s = input.shape
print("in_s:", in_s)
# height x width x in_channels x out_channels
f_s = filter.shape
print("f_s:", f_s)
assert len(in_s) == 4, 'input size rank 4 required!'
assert len(f_s) == 4, 'filter size rank 4 required!'
assert f_s[2] == in_s[3], 'intput channels not match filter channels.'
assert f_s[0] >= stride and f_s[1] >= stride, 'filter should not be less than stride!'
assert padding in [
'SAME', 'VALID'], 'padding value[{0}] not allowded!!'.format(padding)
if padding != 'VALID':
# 提示: 关于SAME和VALID的区别,请参考:https://www.tensorflow.org/api_docs/python/tf/nn/convolution
##################
# Your code here #
# 计算前后补零的个数
if f_s[0] % 2 == 0: # 卷积核尺寸为偶数
height_front_padding = (f_s[0] // 2) - 1
height_end_padding = (f_s[0] // 2)
else: # 卷积核尺寸为奇数
height_front_padding = height_end_padding = (f_s[0] - 1) // 2
if f_s[1] % 2 == 0: # 卷积核尺寸为偶数
width_front_padding = (f_s[1] // 2) - 1
width_end_padding = (f_s[1] // 2)
else: # 卷积核尺寸为奇数
width_front_padding = width_end_padding = (f_s[1] - 1) // 2
print("padding: ", height_front_padding, height_end_padding, width_front_padding, width_end_padding)
# 开一个空间、容器
new_height = in_s[1] + height_front_padding + height_end_padding
new_width = in_s[2] + width_front_padding + width_end_padding
temp = np.zeros((in_s[0], new_height, new_width, in_s[3]), dtype=np.float)
# 开始补零
for b in range(in_s[0]):
for c in range(in_s[3]):
for i in range(in_s[1]):
for j in range(in_s[2]):
temp[b][height_front_padding + i][width_front_padding + j][c] = input[b][i][j][c]
##################
input = temp
in_s = input.shape
out_shape = (np.array(in_s[1: 3]) -
np.array(f_s[:2])) // stride + 1
out_shape = np.concatenate([in_s[:1], out_shape, f_s[-1:]])
output = np.zeros(out_shape)
##################
# Your code here #
# 开始做卷积
for b in range(in_s[0]):
for f in range(f_s[3]):
for i in range(out_shape[1]):
for j in range(out_shape[2]):
# 一次卷积乘的过程
output[b][i][j][f] = (input[b, i * stride:i * stride + f_s[0], j * stride:j * stride + f_s[1]] * filter[:, :, :, f]).sum()
##################
return output
#测试代码
# 先定义个计算图用于运行tf
input_tensor = tf.placeholder(
tf.float32, shape=[None, None, None, None], name='input')
filter_tensor = tf.placeholder(
tf.float32, shape=[None, None, None, None], name='filter')
output_tensor1 = tf.nn.conv2d(
input_tensor, filter_tensor, padding='SAME', strides=[1, 2, 2, 1])
output_tensor2 = tf.nn.conv2d(
input_tensor, filter_tensor, padding='VALID', strides=[1, 3, 3, 1])
output_tensor3 = tf.nn.conv2d(
input_tensor, filter_tensor, padding='SAME', strides=[1, 2, 2, 1])
try:
final_score = 0 # 这个是最终得分
filter = np.random.uniform(size=[5, 5, 3, 8])
output = conv2d(img, filter, 2, 'SAME')
with tf.Session() as sess:
output_tf = sess.run(
output_tensor3,
feed_dict={
input_tensor: img,
filter_tensor: filter
})
print("mine: ", output[0, 1, :, 0])
print("goal: ", output_tf[0, 1, :, 0])
assert output.shape == output_tf.shape, 'shape mismatch [{}] vs [{}]'.format(
output.shape, output_tf.shape)
final_score += 20 # shape算对了得20分
diff = np.mean(np.abs(output - output_tf))
assert diff < 1e-5, 'value mismatch [{}]'.format(
diff) # 如果这一行没有报错的话,那么实现可以认为是正确的。
final_score += 30 # 数值算对了得30分
print('test 1 passed...')
filter = np.random.uniform(size=[5, 5, 3, 8])
output = conv2d(img, filter, 3, 'VALID')
with tf.Session() as sess:
output_tf = sess.run(
output_tensor2,
feed_dict={
input_tensor: img,
filter_tensor: filter
})
assert output.shape == output_tf.shape, 'shape mismatch [{}] vs [{}]'.format(
output.shape, output_tf.shape)
final_score += 20 # shape算对了得20分
diff = np.mean(np.abs(output - output_tf))
assert diff < 1e-5, 'value mismatch [{}]'.format(
diff) # 如果这一行没有报错的话,那么实现可以认为是正确的。
final_score += 30 # 数值算对了得30分
print('test 2 passed...')
except Exception as ex:
print(ex)
print('Your final score:[{}]'.format(final_score))
#测试运行结果
答
你好,我是有问必答小助手,非常抱歉,本次您提出的有问必答问题,技术专家团超时未为您做出解答
本次提问扣除的有问必答次数,将会以问答VIP体验卡(1次有问必答机会、商城购买实体图书享受95折优惠)的形式为您补发到账户。
因为有问必答VIP体验卡有效期仅有1天,您在需要使用的时候【私信】联系我,我会为您补发。