How can I implement the same operation with TF 1.15?
import torch
B, T, N, K = 2,3,4,2
# a is a counter table: T is the number of groups, N the number of possible values per group
a = torch.zeros(T, N, dtype=torch.long)
# x is a batch of data where K elements are selected for each group
x = torch.randint(0, N, (B, T, K))
# the counter should record all data within the batch (per group)
y = x.permute(1,2,0).reshape(T, -1)
# update counter
# scatter_add is out-of-place, so assign the result (or use the in-place a.scatter_add_)
a = a.scatter_add(dim=-1, index=y, src=torch.ones_like(y))
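For reference, the counting that `scatter_add` performs here can also be written as a one-hot encoding followed by a sum over the sample axis. That formulation only needs ops TF 1.15 also has (`tf.one_hot` and `tf.reduce_sum`). The sketch below is written in NumPy rather than TensorFlow so it runs without either framework; the random generator and seed are illustrative choices, not part of the original code.

```python
import numpy as np

B, T, N, K = 2, 3, 4, 2
rng = np.random.default_rng(0)  # illustrative seed
x = rng.integers(0, N, size=(B, T, K))

# Same reshuffle as the PyTorch code: (B, T, K) -> (T, K, B) -> (T, K*B)
y = x.transpose(1, 2, 0).reshape(T, -1)

# One-hot encode every index along a new last axis of size N: (T, K*B, N)
one_hot = (y[..., None] == np.arange(N)).astype(np.int64)

# Summing over the K*B axis gives the per-group counts: (T, N)
a = one_hot.sum(axis=1)
print(a)
```

Note that the intermediate one-hot tensor has shape (T, K*B, N), so this trades memory for simplicity when N or the batch is large.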
As @mhenning mentioned,
for TensorFlow 1.15 you can use tf.scatter_add(). However, these functions -
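Because `tf.scatter_add` adds updates into a variable along its first dimension, one way to apply it to a per-group (T, N) table is to flatten the table to length T*N and offset group `t`'s indices by `t*N`. The NumPy sketch below illustrates that index arithmetic with `np.add.at`, an unbuffered scatter-add that counts duplicate indices. The assumption is that, in TF 1.15, `flat_idx` and a vector of ones would be passed as `indices` and `updates` to `tf.scatter_add` on a 1-D `tf.Variable` of length `T * N`. Treat this as a sketch of the approach, not a tested TensorFlow implementation; `flat_idx` and `flat_counts` are illustrative names.

```python
import numpy as np

B, T, N, K = 2, 3, 4, 2
rng = np.random.default_rng(1)  # illustrative seed
x = rng.integers(0, N, size=(B, T, K))
y = x.transpose(1, 2, 0).reshape(T, -1)  # (T, K*B), as in the PyTorch code

# Shift group t's indices into its own block [t*N, (t+1)*N) of a flat table
flat_idx = (y + N * np.arange(T)[:, None]).ravel()  # (T*K*B,)

# Unbuffered scatter-add: repeated indices are each counted once per occurrence
flat_counts = np.zeros(T * N, dtype=np.int64)
np.add.at(flat_counts, flat_idx, 1)

# Reshape the flat table back to one row of N counters per group
a = flat_counts.reshape(T, N)
print(a)
```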