#python #numpy #tensorflow #deep-learning #tensor
Вопрос:
Я просмотрел тонны статей, в которых говорилось, как энергичное выполнение может помочь мне перебрать тензор, но, похоже, это не работает для меня. Я намерен создать функцию потерь, для которой мне нужно выполнить итерацию по тензору.
def lightness_loss_non_white(predictions, targets):
tf.contrib.eager.enable_eager_execution()
total_elements = (tf.shape(targets)[0] * tf.shape(targets)[1] * tf.shape(targets)[2]
* tf.shape(targets)[3])
#total_elements = tf.to_float(total_elements)
predictions1 = predictions.numpy()
targets1 = targets.numpy()
for i in range(4):
for img,img1 in (predictions1[:,:,:,i*3:(i 1)*3],targets1[:,:,:,i*3:(i 1)*3]):
for col in range(256):
for i in range(256):
if(img[col][i] == [255,255,255] and img1[col][i] == [255,255,255]): # If both are white then we will not consider it
total_elements -= 3
total_elements = tf.to_float(total_elements)
total_loss = tf.div(0.0,total_elements)
"""
predictions = predictions.eval(session=tf.compat.v1.Session())
targets = targets.eval(session=tf.compat.v1.Session())
"""
for i in range(4): #num_out
pred_lig = tf.div(tf.math.reduce_max(predictions[:,:,:,i*3:(i 1)*3],axis=3) tf.math.reduce_min(predictions[:,:,:,i*3:(i 1)*3],axis=3),2)
target_lig = tf.div(tf.math.reduce_max(targets[:,:,:,i*3:(i 1)*3],axis=3) tf.math.reduce_min(targets[:,:,:,i*3:(i 1)*3],axis=3),2)
loss = tf.reduce_sum(tf.square(pred_lig-target_lig))
loss = tf.div(loss, total_elements)
total_loss =loss
return total_loss
Не беспокойтесь обо всем коде. Просто посмотрите на ту часть, где требуется повторение. Я хотел преобразовать прогнозы в массив numpy, который легко повторяется. Я включил нетерпеливое выполнение, но оно все еще не работает. Кроме того, прогнозы и цели являются 4-мерными тензорами, если это помогает, которые, в свою очередь, являются просто изображениями, сложенными вместе.
- Форма предсказаний такова (8,256,256,12)
- Форма целей (? , 256,256,12)
Ошибка возникает из-за преобразования самого тензора предсказаний!
PS: Используемая версия tensorflow-1.x