Tensorflow 相当于 torch.scatter_add

问题描述 投票:0回答:1

如何与tf 1.15实现相同的操作?

import torch

B, T, N, K = 2,3,4,2

# a is a counter table where T is the number of groups
a = torch.zeros(T, N, dtype=torch.long)

# x is a batch of data where K elements are selected for each group
x = torch.randint(0, N, (B, T, K))

# the counter should record all data within the batch (per group)
y = x.permute(1,2,0).reshape(T, -1)

# update counter
a.scatter_add(index=y, dim=-1, src=torch.ones_like(y))
tensorflow pytorch scatter-plot
1个回答
0
投票

正如@mhenning 所提到的,

对于 Tensorflow 版本 1.15,您可以使用 tf.scatter_add()。但是,这些功能 -

  • tf.scatter_add(),
  • tf.scatter_div(),
  • tf.scatter_min(),
  • tf.scatter_max(),
  • tf.scatter_mul() 在最新的 Tensorflow 版本 2.x 中已弃用,仅适用于 Tensorflow 1.x。
© www.soinside.com 2019 - 2024. All rights reserved.