#python #tensorflow #keras
#python #тензорный поток #keras
Вопрос:
Могу ли я каким-либо образом присвоить конкретные имена выходным данным подклассовой модели?
Несмотря на то, что мои выходные слои вызываются layername1, layername2
, и я возвращаю выходные данные как a dict
с именами name1, name2
, выходные данные все равно вызываются output_1, output_2
.
Определение модели
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import tensorflow as tf
from tensorflow import math as tm
class TestModel(Model):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.denselayer = Dense(33, activation='relu')
self.outlayer1 = Dense(1, name='layername1')
self.outlayer2 = Dense(1, name='layername2')
def call(self, inputs):
l_ = self.denselayer(inputs)
out1 = self.outlayer1(l_)
out2 = self.outlayer2(l_)
return {'name1': out1, 'name2': out2}
m = TestModel()
x = tf.zeros((100,100,))
yhat = m(x)
y1 = tf.ones((100,1,))
y2 = 2*tf.ones((100,1,))
m.compile(optimizer='Adam',
#loss = {'layername1': 'mse', 'layername2': 'binary_crossentropy'},
#loss_weights = {'layername1': 1, 'layername2': 1},
loss = ['mse','binary_crossentropy']
)
Это работает:
m.fit(x, [y1, y2])
Это не работает:
dx = tf.data.Dataset.from_tensor_slices(x)
dy = tf.data.Dataset.from_tensor_slices({'name1': y1, 'name2': y2})
xy = tf.data.Dataset.zip((dx, dy)).batch(1)
m.fit(xy)
Это работает:
dy = tf.data.Dataset.from_tensor_slices({'output_1': y1, 'output_2': y2})
xy = tf.data.Dataset.zip((dx, dy)).batch(1)
m.fit(xy)
Комментарии:
1. Что вы подразумеваете под «Это не работает»? Напротив, для меня работает только средний. Если я вызываю
out = m(x) tf.print(out)
, я получаю словарь с ключами ‘name1’ и ‘name2’, как и ожидалось. Первый не работает, потому что модель создает словарь, а вы предоставляете ему массив. Третий случай не соответствует именам ключей в словаре и, следовательно, не может вычислять какие-либо градиенты. Какую версию TensorFlow вы используете? Убедитесь, что у вас есть текущая версия.2. Действительно, это работает с 2.3. Я все еще работал с 2.1.