按照指定轴上的坐标进行过滤

PHOTO EMBED

Sat Oct 08 2022 13:30:50 GMT+0000 (Coordinated Universal Time)

Saved by @linzao

>>> x = torch.randn(3, 4)		# 目标矩阵
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-0.4664,  0.2647, -0.1228, -1.1068],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])	# 在轴上筛选坐标
>>> torch.index_select(x, dim=0, indices)	# 指定筛选对象、轴、筛选坐标
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],
        [-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, dim=1, indices)
tensor([[ 0.1427, -0.5414],
        [-0.4664, -0.1228],
        [-1.1734,  0.7230]])

1
2
3
4
5
6
7
8
9
10
11
12
13
content_copyCOPY

https://blog.csdn.net/tfcy694/article/details/85332953