#tensorflow #tensorflow-datasets #tensorflow2.x
#tensorflow #tensorflow-наборы данных #tensorflow2.x
Вопрос:
Я ищу механизм для сохранения случайного состояния, используемого tf.data.Dataset.shuffle
. Для контекста я хочу иметь возможность воспроизводить результаты обучения при перезапусках.
У меня есть решение (приведенное ниже), но оно не особенно элегантно, и я уверен batch
, что / unbatch
приведет к проблемам с производительностью. Есть ли эквивалентный способ сделать это с помощью Dataset.shuffle
?
import tensorflow as tf
import numpy as np
class Shuffler(tf.Module):
def __init__(self, buffer_size: int, seed: int = 0):
self._buffer_size = buffer_size
self._seed = seed
self._rng = tf.random.Generator.from_seed(seed)
def __call__(self, dataset: tf.data.Dataset):
def map_fn(*args):
vals = self._rng.uniform((self._buffer_size,))
i = tf.argsort(vals)
if len(args) == 1:
(args,) = args
return tf.nest.map_structure(lambda x: tf.gather(x, i), args)
return dataset.batch(self._buffer_size).map(map_fn).unbatch()
def as_list(ds: tf.data.Dataset):
return [x.numpy() for x in ds]
shuffler = Shuffler(5)
chkpt = tf.train.Checkpoint(shuffler=shuffler)
p0 = chkpt.save("/tmp/chkpt-0")
ds = tf.data.Dataset.range(5).apply(shuffler)
expected0 = as_list(ds)
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(ds)
# ensure they're actually shuffled
assert not np.all(expected0 == expected1)
assert set(expected0) == set(expected1)
chkpt.restore(p0)
np.testing.assert_equal(as_list(ds), expected0)
np.testing.assert_equal(as_list(ds), expected1)
# mangle state by iterating over it again
as_list(ds)
# restore p1
chkpt.restore(p1)
np.testing.assert_equal(as_list(ds), expected1)
print("Passed!")
Ответ №1:
Оказывается, состояние уже управляется в итераторе.
import tensorflow as tf
import numpy as np
def as_list(it: tf.data.Iterator, length: int = 5):
return [it.next().numpy() for _ in range(length)]
ds = tf.data.Dataset.range(5).shuffle(5, seed=0).repeat()
it = iter(ds)
chkpt = tf.train.Checkpoint(it=it)
p0 = chkpt.save("/tmp/chkpt-0")
expected0 = as_list(it)
p1 = chkpt.save("/tmp/chkpt-1")
expected1 = as_list(it)
# ensure they're actually shuffled
assert not np.all(expected0 == expected1)
assert set(expected0) == set(expected1)
chkpt.restore(p0)
np.testing.assert_equal(as_list(it), expected0)
np.testing.assert_equal(as_list(it), expected1)
# mangle state by iterating over it again
as_list(it)
# restore p1
chkpt.restore(p1)
np.testing.assert_equal(as_list(it), expected1)
print("Passed!")