#python #tensorflow #conv-neural-network
Вопрос:
Я создаю сверточную нейронную сеть на python. Чтобы получить сверточный слой, я пытаюсь сделать это:
def convLayer(x, kHeight, kWidth, strideX, strideY, featureNum, name, padding = "SAME", groups = 1): """convolution""" channel = int(X_train_tf.get_shape()[-1]) conv = lambda a, b: tf.nn.conv2d(a, b, strides = [1, strideY, strideX, 1], padding = padding) conv.double() with tf.compat.v1.variable_scope(name) as scope: w = tf.compat.v1.get_variable("w", shape = [kHeight, kWidth, int(channel/groups), featureNum]) b = tf.compat.v1.get_variable("b", shape = [featureNum]) xNew = tf.split(value = X_train_tf, num_or_size_splits = groups, axis = 3) wNew = tf.split(value = w, num_or_size_splits = groups, axis = 3) featureMap = [conv(t1, t2) for t1, t2 in zip(xNew, wNew)] mergeFeatureMap = tf.concat(axis = 3, values = featureMap) print mergeFeatureMap.shape out = tf.nn.bias_add(mergeFeatureMap, b) return tf.nn.relu(tf.reshape(out, mergeFeatureMap.get_shape().as_list()), name = scope.name)
Проблема в том, что FeatureMap = [conv(t1, t2) для t1, t2 в zip(xNew, wNew)]
потому что conv ожидал двойного тензора и получил тензор с плавающей точкой. Я попытался изменить t1 и t2 ( я назвал x и y) с помощью:
for x, y in zip(xNew, wNew): x.DoubleTensor() (x,y)= torch.from_numpy(x,y).double() y = torch.from_numpy(y).double() x,y=x.type(torch.DoubleTensor),y.type(torch.DoubleTensor)
Но любая часть этого кода работает. самая распространенная ошибка-это
объект «tensorflow.python.framework.ops.EagerTensor» не имеет атрибута» тип » / «двойной» … и т. Д.
У кого-нибудь есть решение для этого? Спасибо.
Комментарии:
1. Пожалуйста, укажите строку, в которой вы получили ошибку .
2. Вам следует избегать использования tf.compat.v1, это может привести к ошибкам, и это бесполезно для последних версий tf
3. ошибка в строке 16 «Карта функций = [conv(t1, t2) для t1, t2 в zip(xNew, wNew)]» проблема заключается в аргументах conv. Я обнаружил, что использование tf.compat.v1 для работы с get_variable, потому что оно не дополнялось без tf.compat.v1.