pytorch常用mask命令

pytorch常用mask命令,第1张

pytorch常用mask命令

文章目录
  • 前言
  • 1.Tensor.masked_fill_(mask, value)
    • 举个例子
  • 2.torch.masked_select(input, mask, *, out=None) → Tensor
    • 举个例子
  • 3.Tensor.masked_scatter_(mask, source)
    • 举个例子


前言


mask是深度学习里面常用的 *** 作,最近在研究transformer的pytorch代码,总能看到各种mask的命令,在这里总结一下

1.Tensor.masked_fill_(mask, value)

Fills elements of self tensor with value where mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor.

Parameters
mask (BoolTensor) – the boolean mask
value (float) – the value to fill in with
举个例子
import torch
mask = torch.tensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]]).bool()
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(3,3)
a.masked_fill(mask, 0)
# tensor([[ 0.0000,  0.6781,  0.6532],
#         [-1.2078,  0.0000,  0.4964],
#         [ 0.2192, -0.6276,  0.0000]])
a.masked_fill(~mask, 0)#可以对mask取反
# tensor([[-0.4438,  0.0000,  0.0000],
#         [ 0.0000,  1.3907,  0.0000],
#         [ 0.0000,  0.0000,  2.2462]])
2.torch.masked_select(input, mask, *, out=None) → Tensor

Returns a new 1-D tensor which indexes the input tensor according to the boolean mask mask which is a BoolTensor.
The shapes of the mask tensor and the input tensor don’t need to match, but they must be broadcastable.

(注意)The returned tensor does not use the same storage as the original tensor

Parameters
input (Tensor) – the input tensor.
mask (BoolTensor) – the tensor containing the binary mask to index with
举个例子
import torch
x = torch.randn(3,4)
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
mask = x > 0.5
# tensor([[ 0.2914, -0.1056,  0.4946,  0.2926],
#         [-1.0920, -0.2156,  3.0989, -0.9067],
#         [-0.1522,  1.9527,  0.1660,  0.8310]])
torch.masked_select(x, mask)
# tensor([3.0989, 1.9527, 0.8310])
3.Tensor.masked_scatter_(mask, source)

Tensor.masked_scatter_(mask, source)
Copies elements from source into self tensor at positions where the mask is True. The shape of mask must be broadcastable with the shape of the underlying tensor. The source should have at least as many elements as the number of ones in mask

source大小和mask至少一样,能够被广播到Tensor上,或者source和Tensor一样
作用就是把source里mask是true的位置挑出来给Tensor

Parameters
mask (BoolTensor) – the boolean mask
source (Tensor) – the tensor to copy from
举个例子
import torch
mask = torch.BoolTensor([[1, 0, 0], [0, 1, 0],  [0, 0, 1]])
# tensor([[ True, False, False],
#         [False,  True, False],
#         [False, False,  True]])
a = torch.randn(2,3,3)
s = torch.ones_like(a)
a.masked_scatter(mask, s)
# tensor([[[ 1.0000, -0.1560, -0.7760],
#          [-0.5192,  1.0000, -0.1709],
#          [ 0.2091,  0.5650,  1.0000]],

#         [[ 1.0000,  0.0623, -0.1447],
#          [-1.2910,  1.0000, -1.2722],
#          [-0.7864, -0.1118,  1.0000]]])

欢迎分享,转载请注明来源:内存溢出

原文地址: https://www.outofmemory.cn/zaji/5496007.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-12
下一篇 2022-12-12

发表评论

登录后才能评论

评论列表(0条)

保存