Сохранение состояния перемешанного набора данных

#tensorflow #tensorflow-datasets #tensorflow2.x

#tensorflow #tensorflow-наборы данных #tensorflow2.x

Вопрос:

Я ищу механизм для сохранения случайного состояния, используемого tf.data.Dataset.shuffle . Для контекста я хочу иметь возможность воспроизводить результаты обучения при перезапусках.

У меня есть решение (приведенное ниже), но оно не особенно элегантно, и я уверен batch , что / unbatch приведет к проблемам с производительностью. Есть ли эквивалентный способ сделать это с помощью Dataset.shuffle ?

 import tensorflow as tf
import numpy as np


class Shuffler(tf.Module):
    def __init__(self, buffer_size: int, seed: int = 0):
        self._buffer_size = buffer_size
        self._seed = seed
        self._rng = tf.random.Generator.from_seed(seed)

    def __call__(self, dataset: tf.data.Dataset):
        def map_fn(*args):
            vals = self._rng.uniform((self._buffer_size,))
            i = tf.argsort(vals)
            if len(args) == 1:
                (args,) = args
            return tf.nest.map_structure(lambda x: tf.gather(x, i), args)

        return dataset.batch(self._buffer_size).map(map_fn).unbatch()


def as_list(ds: tf.data.Dataset):
    return [x.numpy() for x in ds]


shuffler = Shuffler(5)
chkpt = tf.train.Checkpoint(shuffler=shuffler)
p0 = chkpt.save("/tmp/chkpt-0")
ds = tf.data.Dataset.range(5).apply(shuffler)
expected0 = as_list(ds)
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(ds)
# ensure they're actually shuffled
assert not np.all(expected0 == expected1)
assert set(expected0) == set(expected1)

chkpt.restore(p0)
np.testing.assert_equal(as_list(ds), expected0)

np.testing.assert_equal(as_list(ds), expected1)
# mangle state by iterating over it again
as_list(ds)

# restore p1
chkpt.restore(p1)
np.testing.assert_equal(as_list(ds), expected1)
print("Passed!")
  

Ответ №1:

Оказывается, состояние уже управляется в итераторе.

 import tensorflow as tf
import numpy as np


def as_list(it: tf.data.Iterator, length: int = 5):
    return [it.next().numpy() for _ in range(length)]


ds = tf.data.Dataset.range(5).shuffle(5, seed=0).repeat()
it = iter(ds)
chkpt = tf.train.Checkpoint(it=it)
p0 = chkpt.save("/tmp/chkpt-0")
expected0 = as_list(it)
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(it)
# ensure they're actually shuffled
assert not np.all(expected0 == expected1)
assert set(expected0) == set(expected1)

chkpt.restore(p0)
np.testing.assert_equal(as_list(it), expected0)

np.testing.assert_equal(as_list(it), expected1)
# mangle state by iterating over it again
as_list(it)

# restore p1
chkpt.restore(p1)
np.testing.assert_equal(as_list(it), expected1)
print("Passed!")