tf.sparse.reshape(tf.sparse.split()) : ошибка типа: ввод должен быть SparseTensor

#python #tensorflow #tensorflow2.0

#python #тензорный поток #tensorflow2.0

Вопрос:

Я пытаюсь преобразовать плотную матрицу в вычисление разреженной матрицы в tensorflow. При попытке reshape после использования возникает ошибка tf.sparse.split( . Ниже приведен игрушечный пример, демонстрирующий проблему.

Плотной матрицей тензорного потока

 import numpy as np
import tensorflow as tf
a = np.array([[1, 0, 2, 0,0,1], [3, 0, 0, 4,1,0]])

a_t = tf.constant(a)
a_t_rshp = tf.reshape(tf.split(a_t,2,axis = 1),[2,2,3])
 

Разреженной матрицей тензорного потока

 a_t_st = tf.sparse.from_dense(a_t)
a_t_st_rshp = tf.sparse.reshape(tf.sparse.split(a_t_st,2,axis = 1),[2,2,3])

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-14-3dff37aef5b4> in <module>
----> 1 a_t_st_rshp = tf.sparse.reshape(tf.sparse.split(a_t_st,2,axis = 1),[2,2,3])

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/tensorflow/python/ops/sparse_ops.py in sparse_reshape(sp_input, shape, name)
    886     ValueError:  If `shape` has more than one inferred (== -1) dimension.
    887   """
--> 888   sp_input = _convert_to_sparse_tensor(sp_input)
    889   shape = math_ops.cast(shape, dtype=dtypes.int64)
    890 

/Users/Mine/Python/tf2_4_env/lib/python3.6/site-packages/tensorflow/python/ops/sparse_ops.py in _convert_to_sparse_tensor(sp_input)
     70     return sparse_tensor.SparseTensor.from_value(sp_input)
     71   if not isinstance(sp_input, sparse_tensor.SparseTensor):
---> 72     raise TypeError("Input must be a SparseTensor.")
     73   return sp_input
     74 
 

не могли бы вы помочь мне решить эту проблему?

Комментарии:

1. у меня вообще не было никаких проблем с вашим кодом. Я работаю на python 3.7 и tf 2.4 какую версию вы используете?

2. @CrazyBrazilian вы пробовали запускать a_t_st = tf.sparse.from_dense(a_t) a_t_st_rshp = tf.sparse.reshape(tf.sparse.split(a_t_st,2,axis = 1),[2,2,3]) ? версии: Python 3.6.4 и TensorFlow 2.4.0

3. Похоже, это ошибка в Tensorflow, не могли бы вы поднять эту проблему в Tensorflow github reportository github.com/tensorflow/tensorflow/issues . Спасибо!