Соответствие API модели keras

#tensorflow #keras

Вопрос:

Я пытаюсь создать что-то похожее на Word2Vec со следующим:

 class Word2Vec(keras.Model):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embedding = keras.layers.Embedding(
                                vocab_size, 
                                embedding_dim,
                                input_length=1,
                                name="w2v_embedding"
        )
        self.dot = keras.layers.Dot(axes=(-1, -1))

    def call(self, data):
        target, context = data
        we = self.embedding(target)
        ce = self.embedding(context)
        return self.dot([we, ce])
 

и предположим, что потеря заключается в следующем:

 def loss(similarity):
    log_prob = tf.math.log(tf.sigmoid(similarity))
    return -tf.math.reduce_mean(log_prob)
 

Я пытаюсь соответствовать приведенной выше модели со словами и их контекстами, но сталкиваюсь с ошибкой: OperatorNotAllowedInGraphError: iterating over tf.Тензор is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature. .

Предположим, у меня есть фиктивный набор данных, который выглядит следующим образом:

 N = 10000
V = 100
word = np.random.randint(0, V, N)
context = np.random.randint(0, V, (N, 4))
 

То, что я пытался сделать, было:

 word2vec = Word2Vec(V, 32)
word2vec.compile(loss=loss, optimizer="adam")
word2vec.fit(tf.data.Dataset.from_tensor_slices((word, context)), batch_size=128, epochs=1)
 

когда я получил вышеуказанную ошибку. Есть какие-нибудь мысли о том, как это исправить?

Я понимаю, что это не точная модель word2vec, но меня больше волнует понимание API tensorflow/ keras и то, как это работает, чем фактическая реализация на бумаге.

Правка 1

Редактируемый блокнот kaggle с полным кодом доступен здесь: https://www.kaggle.com/sachin/word-vectors

Комментарии:

1. ваша потеря, похоже, не включает цель. это должна быть потеря(цель, сходство)… Я предлагаю вам это adventuresinmachinelearning.com/word2vec-keras-tutorial

2. Да, я понимаю, но в данном конкретном случае все цели верны. Следовательно, почему я использовал функцию пользовательских потерь выше.

Ответ №1:

Я думаю, что это вызывает проблему:

 target, context = data
 

Попробуйте это вместо этого:

 target = data[0]
context = data[1]
 

Комментарии:

1. Боюсь, это не сработало. Я добавил блокнот каггла выше, если это поможет.