#python #tensorflow
#python #тензорный поток
Вопрос:
Я читаю набор данных формата TFRecord, состоящий из изображения и его метки, то есть моей целевой переменной. Метка представлена 5 целыми числами в диапазоне от 0 до 4. Функция, которую я использую для чтения набора данных, следующая:
def read_tfrecord(sample):
tfrecord_format = {
"image": tf.io.FixedLenFeature([], tf.string),
"target": tf.io.FixedLenFeature([], tf.int64)
}
sample = tf.io.parse_single_example(sample, tfrecord_format)
image = decode_image(example['image'])
label = tf.cast(example['target'], tf.int32)
return image, tf.reshape(tf.one_hot([label], depth=5, axis=-1), [-1])
Код работает, но я хотел бы внести изменения. Я хотел бы изменить метку следующим образом: 0,1,2,3 на 0 и 4 на 1. Я попытался применить словарь к label, но я не очень хорошо знаю, как обращаться с тензорами.
Ответ №1:
Если вы хотите изменить сам тензор, вы можете использовать tf.map_fn
. Он будет применять функцию к каждому элементу. Я буду использовать постоянные тензоры здесь, чтобы продемонстрировать идею.
fn = lambda x: tf.constant(1) if tf.equal(x,4) else tf.constant(0)
res = tf.map_fn(fn, tf.constant([0,1,2,3,4]))
Когда вы печатаете тензор res
, он показывает of <tf.Tensor: shape=(5,), dtype=int32, numpy=array([0, 0, 0, 0, 1], dtype=int32)>
.