tf.shape(x)、x.shape、x.get_shape()函数解析(最清晰的解释) 欢迎关注WX公众号:【程序员管小亮】

tf.shape(x)、x.shape、x.get_shape()函数解析(最清晰的解释)
欢迎关注WX公众号:【程序员管小亮】

最近看到了tf.shape(x)、x.shape和x.get_shape()三个函数,不知道他们的差别,所以记录一下。

import tensorflow as tf

x = tf.constant([[0,1,2],[3,4,5],[6,7,8]])

print(type(x.shape))
print(type(x.get_shape()))
print(type(tf.shape(x)))
> <class 'tensorflow.python.framework.tensor_shape.TensorShapeV1'>
> <class 'tensorflow.python.framework.tensor_shape.TensorShapeV1'>
> <class 'tensorflow.python.framework.ops.Tensor'>

可以看到s.shape和x.get_shape()都是返回TensorShapeV1类型对象,而tf.shape(x)返回的是Tensor类型对象。

除此之外,对tf.shape(x)来说,其中x可以是tensor,也可不是tensor,返回是一个tensor。而对x.get_shape()来说,只有tensor有这个方法, 返回是一个tuple。

所以,如果在运行下面代码的时候,

x = tf.placeholder(tf.float32, shape=[None, 227] )

想知道None到底是多少,这时候,只能通过tf.shape(x)[0]这种方式来获得。

而想要获得维度信息,则需要调用前两种方法。

import tensorflow as tf

x = tf.constant([[0,1,2],[3,4,5],[6,7,8]])

print(x.shape)
print(x.get_shape())
print(tf.shape(x))
print(tf.rank(x))
> (3, 3)
> (3, 3)
> Tensor("Shape_3:0", shape=(2,), dtype=int32)
> Tensor("Rank_2:0", shape=(), dtype=int32)

或者是调用ts.as_list()方法,返回的是Python的list。

import tensorflow as tf

x = tf.constant([[0,1,2],[3,4,5],[6,7,8]])

x.shape.as_list()
#x.get_shape().as_list()
> [3, 3]

python课程推荐。
tf.shape(x)、x.shape、x.get_shape()函数解析(最清晰的解释)
欢迎关注WX公众号:【程序员管小亮】