От тензора с плавающей точкой к двойному тензору

#python #tensorflow #conv-neural-network

Вопрос:

Я создаю сверточную нейронную сеть на python. Чтобы получить сверточный слой, я пытаюсь сделать это:

 def convLayer(x, kHeight, kWidth, strideX, strideY,  featureNum, name, padding = "SAME", groups = 1):  """convolution""" channel = int(X_train_tf.get_shape()[-1]) conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding)  conv.double()  with tf.compat.v1.variable_scope(name) as scope:  w = tf.compat.v1.get_variable("w", shape = [kHeight, kWidth, int(channel/groups), featureNum])  b = tf.compat.v1.get_variable("b", shape = [featureNum])    xNew = tf.split(value = X_train_tf, num_or_size_splits = groups, axis = 3)  wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3)    featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)]  mergeFeatureMap = tf.concat(axis = 3, values = featureMap)  print mergeFeatureMap.shape  out = tf.nn.bias_add(mergeFeatureMap, b) return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name)  

Проблема в том, что FeatureMap = [conv(t1, t2) для t1, t2 в zip(xNew, wNew)]

потому что conv ожидал двойного тензора и получил тензор с плавающей точкой. Я попытался изменить t1 и t2 ( я назвал x и y) с помощью:

 for x, y in zip(xNew, wNew):  x.DoubleTensor()  (x,y)= torch.from_numpy(x,y).double()  y = torch.from_numpy(y).double()  x,y=x.type(torch.DoubleTensor),y.type(torch.DoubleTensor)  

Но любая часть этого кода работает. самая распространенная ошибка-это

объект «tensorflow.python.framework.ops.EagerTensor» не имеет атрибута» тип » / «двойной» … и т. Д.

У кого-нибудь есть решение для этого? Спасибо.

Комментарии:

1. Пожалуйста, укажите строку, в которой вы получили ошибку .

2. Вам следует избегать использования tf.compat.v1, это может привести к ошибкам, и это бесполезно для последних версий tf

3. ошибка в строке 16 «Карта функций = [conv(t1, t2) для t1, t2 в zip(xNew, wNew)]» проблема заключается в аргументах conv. Я обнаружил, что использование tf.compat.v1 для работы с get_variable, потому что оно не дополнялось без tf.compat.v1.