torch.where()

Sat Oct 08 2022 13:31:16 GMT+0000 (Coordinated Universal Time)

Saved by @linzao

>>> import torch
>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> # Take x where x > 0, otherwise take the matching element of y (1.0)
>>> torch.where(x > 0, x, y)
tensor([[1.4013, 1.0000],
        [1.0000, 0.9267],
        [1.0000, 0.4302]])
>>> x
tensor([[ 1.4013, -0.9960],
        [-0.3715,  0.9267],
        [-0.7163,  0.4302]])
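
Because the session above uses torch.randn, its values change on every run. The sketch below repeats the same selection with fixed inputs so the result can be checked by hand, and also shows the single-argument form of torch.where, which returns index tensors instead of selected values. The variable names (selected, rows, cols) and the input values are chosen here for illustration only.

```python
import torch

# Fixed values so the result is reproducible (the session above uses torch.randn).
x = torch.tensor([[1.5, -1.0],
                  [-0.5, 2.0]])
y = torch.ones(2, 2)

# Element-wise select: keep x where the condition holds, otherwise take y.
selected = torch.where(x > 0, x, y)

# Single-argument form: returns a tuple of index tensors, one per dimension,
# giving the positions where the condition is True.
rows, cols = torch.where(x > 0)

print(selected)
print(rows.tolist(), cols.tolist())
```

Here the two negative entries of x are replaced by 1.0 from y, and rows/cols list the coordinates of the positive entries.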

https://blog.csdn.net/tfcy694/article/details/85332953