>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> torch.where(x > 0, x, y)  # elementwise: x where x > 0, otherwise the matching entry of y
tensor([[1.4013, 1.0000],
        [1.0000, 0.9267],
        [1.0000, 0.4302]])
>>> x
tensor([[ 1.4013, -0.9960],
        [-0.3715,  0.9267],
        [-0.7163,  0.4302]])