带有索引列表的割炬张量

问题描述 投票:1回答:1

我正在做一个强化学习项目,我试图获得一个张量,该张量表示所有给定动作的预期收益。我对batch大小为0或1(两个潜在动作)的值选择的动作有一个长张量。对于大小为batch * action_size的每个动作,我都有一个预期的张量,并且我希望大小为batch的张量。

例如,如果批次大小为4,那么我有

action = tensor([1,0,0,1])
expectedReward = tensor([[3,7],[5,9],[-1,12],[0,1]])

我想要的是

rewardForActions = tensor([7,5,-1,1])

我以为this会回答我的问题,但是根本不一样,因为如果我采用该解决方案,它将以4 * 4张量结束,从每行中选择4次,而不是一次。

有什么想法吗?

python pytorch slice
1个回答
1
投票

您可以做

rewardForActions = expectedReward.index_select(1, action).diagonal()  
# tensor([ 7,  5, -1,  1])                                                                                                                                                                                                            
© www.soinside.com 2019 - 2024. All rights reserved.