深入浅出TensorFlow2函数——tf.data.Dataset.shuffle

深入浅出TensorFlow2函数——tf.data.Dataset.shuffle,第1张

深入浅出TensorFlow2函数——tf.data.Dataset.shuffle

分类目录:《深入浅出TensorFlow2函数》总目录


函数:

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None, name=None)

该函数可以随机洗牌此数据集的元素。此数据集使用buffer_size的元素填充缓冲区,然后从该缓冲区中随机采样元素,用新元素替换所选元素。为了实现完美的洗牌,需要缓冲区大小大于或等于数据集的完整大小。

例如,如果您的数据集包含10000个元素,但buffer_size设置为1000,则shuffle最初将仅从缓冲区中的前1000个元素中选择一个随机元素。一旦选择一个元素,其在缓冲区中的空间将被下一个(即1001个)元素替换,从而保持1000个元素的缓冲区。而reshuffle_each_iteration控制每次迭代的洗牌顺序是否应该不同。

在TensorFlow2.X中,tf.data.Dataset对象是Python的iterables,所以我们也可以用Python的循环遍历:

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]

参数:

参数意义buffer_size[tf.int64 /tf.Tensor]表示新数据集将从此数据集中采样的元素数。seed[可选,tf.int64 /tf.Tensor]表示将用于创建分布的随机种子。reshuffle_each_iteration[可选,tf.bool]如果为True,则表示每次迭代数据集时都应伪随机地重新洗牌,默认为True。name[可选]tf.data *** 作的名称

返回值:

返回值意义Dataset一个tf.data.Dataset的数据集。

函数实现:

  def shuffle(self,
              buffer_size,
              seed=None,
              reshuffle_each_iteration=None,
              name=None):
    """Randomly shuffles the elements of this dataset.
    This dataset fills a buffer with `buffer_size` elements, then randomly
    samples elements from this buffer, replacing the selected elements with new
    elements. For perfect shuffling, a buffer size greater than or equal to the
    full size of the dataset is required.
    For instance, if your dataset contains 10,000 elements but `buffer_size` is
    set to 1,000, then `shuffle` will initially select a random element from
    only the first 1,000 elements in the buffer. once an element is selected,
    its space in the buffer is replaced by the next (i.e. 1,001-st) element,
    maintaining the 1,000 element buffer.
    `reshuffle_each_iteration` controls whether the shuffle order should be
    different for each epoch. In TF 1.X, the idiomatic way to create epochs
    was through the `repeat` transformation:
    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    dataset = dataset.repeat(2)
    # [1, 0, 2, 1, 2, 0]
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    dataset = dataset.repeat(2)
    # [1, 0, 2, 1, 0, 2]
    ```
    In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
    possible to also create epochs through Python iteration:
    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    list(dataset.as_numpy_iterator())
    # [1, 2, 0]
    ```
    ```python
    dataset = tf.data.Dataset.range(3)
    dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    list(dataset.as_numpy_iterator())
    # [1, 0, 2]
    ```
    Args:
      buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
        elements from this dataset from which the new dataset will sample.
      seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
        seed that will be used to create the distribution. See
        `tf.random.set_seed` for behavior.
      reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
        that the dataset should be pseudorandomly reshuffled each time it is
        iterated over. (Defaults to `True`.)
      name: (Optional.) A name for the tf.data operation.
    Returns:
      Dataset: A `Dataset`.
    """
    return ShuffleDataset(
        self, buffer_size, seed, reshuffle_each_iteration, name=name)

欢迎分享,转载请注明来源:内存溢出

原文地址: http://www.outofmemory.cn/zaji/5658707.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存