- Preface
- 1.Tensor.masked_fill_(mask, value)
- Example
- 2.torch.masked_select(input, mask, *, out=None) → Tensor
- Example
- 3.Tensor.masked_scatter_(mask, source)
- Example
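Preface
Masks are used constantly in deep learning. I have recently been reading PyTorch implementations of the Transformer and keep running into various mask operations, so here is a summary of them.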
1.Tensor.masked_fill_(mask, value)
Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.
Parameters
- mask (BoolTensor) – the boolean mask
- value (float) – the value to fill in with
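In Transformer code, masked_fill most often appears as a causal (look-ahead) attention mask. What follows is a minimal sketch under one assumption: random numbers stand in for real attention scores. The names `scores`, `causal_mask` and `attn` are illustrative only. It shows two things. First, a (4, 4) mask broadcasts against (2, 4, 4) scores. Second, filling with `-inf` gives those positions zero weight after softmax.

```python
import torch

torch.manual_seed(0)
# Hypothetical attention scores: batch of 2, sequence length 4 (random stand-ins).
scores = torch.randn(2, 4, 4)

# True above the diagonal = positions a query must not attend to (future tokens).
# The (4, 4) mask is broadcast against the (2, 4, 4) scores.
causal_mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()

# Out-of-place version: returns a new tensor, `scores` is left unchanged.
masked_scores = scores.masked_fill(causal_mask, float("-inf"))

# exp(-inf) == 0, so masked positions receive exactly zero attention weight.
attn = torch.softmax(masked_scores, dim=-1)

# In-place version (trailing underscore) modifies `scores` itself.
scores.masked_fill_(causal_mask, float("-inf"))
```

Filling with `-inf` is safe here only because the diagonal is never masked. A row with every entry set to `-inf` would make softmax return NaN.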
Example
```python
import torch

mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])

a = torch.randn(3, 3)  # random input, so your numbers will differ

# masked_fill (no underscore) returns a new tensor; masked_fill_ would modify a in place
a.masked_fill(mask, 0)
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])

a.masked_fill(~mask, 0)  # the mask can be inverted with ~
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])
```
2.torch.masked_select(input, mask, *, out=None) → Tensor
Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.
(Note) The returned tensor does not use the same storage as the original tensor.
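The storage note has a practical consequence. Writing into the result of masked_select never changes the input. Here is a small check, with tensor values chosen only for illustration:

```python
import torch

x = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0]])
mask = x > 0

# masked_select copies the selected values into a new 1-D tensor.
selected = torch.masked_select(x, mask)  # tensor([1., 3., 5.])

# Zeroing the copy leaves x untouched, because no storage is shared.
selected.zero_()
```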
Parameters
- input (Tensor) – the input tensor.
- mask (BoolTensor) – the tensor containing the binary mask to index with
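The mask does not need to match the input's shape, as long as it broadcasts. In the sketch below (illustrative values), a 1-D mask of length 4 selects columns from a (3, 4) tensor. The results always come back in row-major order. For a same-shape mask, the result equals plain boolean indexing `x[mask]`.

```python
import torch

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# A (4,) mask is broadcast across the rows: keep columns 0 and 2.
col_mask = torch.tensor([True, False, True, False])
out = torch.masked_select(x, col_mask)  # tensor([ 0,  2,  4,  6,  8, 10])

# With a same-shape mask, masked_select matches boolean indexing.
full_mask = x % 3 == 0
same = torch.equal(torch.masked_select(x, full_mask), x[full_mask])  # True
```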
Example
```python
import torch

x = torch.randn(3, 4)  # random input, so your numbers will differ
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])

mask = x > 0.5
# tensor([[False, False, False, False],
#         [False, False,  True, False],
#         [False,  True, False,  True]])

torch.masked_select(x, mask)  # always a 1-D tensor
# tensor([3.0989, 1.9527, 0.8310])
```
3.Tensor.masked_scatter_(mask, source)
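Copies elements from source into self tensor at positions where the mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The source should have at least as many elements as the number of ones in mask.
So source does not need the same shape as mask or as the tensor. It only needs at least as many elements as there are True entries in mask.
In effect, PyTorch walks through the tensor in row-major order, and each position where mask is True receives the next element of source. The mask applies to self, not to source. The values written are therefore the first elements of source, taken in order, not the elements of source that sit at the masked positions.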
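The ordering rule is easiest to see with integer tensors, where every value is distinct (values chosen for illustration). masked_scatter fills the True positions of self with consecutive elements of source in row-major order. It does not copy source's values at the masked positions.

```python
import torch

target = torch.zeros(2, 5, dtype=torch.long)
mask = torch.tensor([[0, 0, 0, 1, 1],
                     [1, 1, 0, 1, 1]], dtype=torch.bool)
source = torch.arange(10).reshape(2, 5)
# tensor([[0, 1, 2, 3, 4],
#         [5, 6, 7, 8, 9]])

# The six True positions receive source's first six elements: 0, 1, 2, 3, 4, 5.
out = target.masked_scatter(mask, source)
# tensor([[0, 0, 0, 0, 1],
#         [2, 3, 0, 4, 5]])
```

By contrast, `source[mask]` would give `[3, 4, 5, 6, 8, 9]`, which is not what ends up in `out`.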
Parameters
- mask (BoolTensor) – the boolean mask
- source (Tensor) – the tensor to copy from
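masked_select and masked_scatter traverse the mask in the same row-major order, so they pair naturally. You can pull the masked values out, transform them, and scatter them back into the same positions. Here is a minimal sketch with illustrative values; for this simple elementwise case it agrees with `torch.where`.

```python
import torch

x = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0]])
mask = x < 0

# Extract the negative entries (row-major order), transform them, write them back.
neg = torch.masked_select(x, mask)      # tensor([-2., -4., -6.])
y = x.masked_scatter(mask, neg.abs())   # out-of-place: x is unchanged

# y: tensor([[1., 2., 3.],
#            [4., 5., 6.]])
```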
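Example
```python
import torch

mask = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=torch.bool)
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])

a = torch.randn(2, 3, 3)  # random input, so your numbers will differ
s = torch.ones_like(a)

# The (3, 3) mask is broadcast over the first dimension of a; its 6 True
# positions are filled, in order, with the first 6 elements of s (all ones here)
a.masked_scatter(mask, s)
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],
#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])
```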