分类目录:《深入浅出TensorFlow2函数》总目录
函数:
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None, name=None)
该函数可以随机洗牌此数据集的元素。此数据集使用buffer_size的元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的洗牌,需要缓冲区大小大于或等于数据集的完整大小。
例如,如果您的数据集包含10000个元素,但buffer_size设置为1000,则shuffle最初将仅从缓冲区中的前1000个元素中选择一个随机元素。一旦选择一个元素,其在缓冲区中的空间将被下一个(即1001个)元素替换,从而保持1000个元素的缓冲区。而reshuffle_each_iteration控制每次迭代的洗牌顺序是否应该不同。
在TensorFlow2.X中,tf.data.Dataset对象是Python的iterables,所以我们也可以用Python的循环遍历:
dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=True) list(dataset.as_numpy_iterator()) # [1, 0, 2] list(dataset.as_numpy_iterator()) # [1, 2, 0]
参数:
返回值:
函数实现:
def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None, name=None): """Randomly shuffles the elements of this dataset. This dataset fills a buffer with `buffer_size` elements, then randomly samples elements from this buffer, replacing the selected elements with new elements. For perfect shuffling, a buffer size greater than or equal to the full size of the dataset is required. For instance, if your dataset contains 10,000 elements but `buffer_size` is set to 1,000, then `shuffle` will initially select a random element from only the first 1,000 elements in the buffer. once an element is selected, its space in the buffer is replaced by the next (i.e. 1,001-st) element, maintaining the 1,000 element buffer. `reshuffle_each_iteration` controls whether the shuffle order should be different for each epoch. In TF 1.X, the idiomatic way to create epochs was through the `repeat` transformation: ```python dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=True) dataset = dataset.repeat(2) # [1, 0, 2, 1, 2, 0] dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=False) dataset = dataset.repeat(2) # [1, 0, 2, 1, 0, 2] ``` In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it possible to also create epochs through Python iteration: ```python dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=True) list(dataset.as_numpy_iterator()) # [1, 0, 2] list(dataset.as_numpy_iterator()) # [1, 2, 0] ``` ```python dataset = tf.data.Dataset.range(3) dataset = dataset.shuffle(3, reshuffle_each_iteration=False) list(dataset.as_numpy_iterator()) # [1, 0, 2] list(dataset.as_numpy_iterator()) # [1, 0, 2] ``` Args: buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of elements from this dataset from which the new dataset will sample. seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random seed that will be used to create the distribution. See `tf.random.set_seed` for behavior. reshuffle_each_iteration: (Optional.) A boolean, which if true indicates that the dataset should be pseudorandomly reshuffled each time it is iterated over. (Defaults to `True`.) name: (Optional.) A name for the tf.data operation. Returns: Dataset: A `Dataset`. """ return ShuffleDataset( self, buffer_size, seed, reshuffle_each_iteration, name=name)
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)