I need to create an instance of AdamParamState. I looked at the adam.cpp source as an example and copied the code below from it accordingly. However, even with the headers shown, the compiler still does not recognize AdamParamState.
I would appreciate any help or comments on this.
#include <torch/optim/adam.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>
#include <ATen/ATen.h>
#include <c10/util/flat_hash_map.h>  // ska::flat_hash_map
#include <iostream>                  // std::cout
void get_state(torch::optim::Optimizer *optimizer) {
  for (auto& group : optimizer->param_groups()) {
    for (auto& p : group.params()) {
      // Skip parameters that have no gradient yet.
      if (!p.grad().defined()) {
        continue;
      }
      auto grad = p.grad();
      TORCH_CHECK(!grad.is_sparse(),
          "Adam does not support sparse gradients"/*, please consider SparseAdam instead*/);

      // The optimizer's per-parameter state map, keyed by the string form of
      // the parameter's TensorImpl pointer (the same key adam.cpp uses).
      ska::flat_hash_map<std::string, std::unique_ptr<torch::optim::OptimizerParamState>>& state_ =
          optimizer->state();
      auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));

      // Print the parameter's rank and its first two sizes, for debugging.
      auto tmp_ = p.dim();
      int tmp_0 = 0;
      int tmp_1 = 0;
      if (tmp_ > 0)
        tmp_0 = p.size(0);
      if (tmp_ > 1)
        tmp_1 = p.size(1);
      std::cout << tmp_ << " " << tmp_0 << " " << tmp_1 << std::endl;

      // auto& options = static_cast<AdamOptions&>(group.options());
      // This is the cast the compiler rejects: AdamParamState is not recognized.
      auto& state = static_cast<AdamParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
    }
  }
}
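
For reference, this is the pattern I am trying to reproduce from adam.cpp, where a fresh state entry is created and stored in the optimizer's state map. The sketch below is my own adaptation, written with full torch::optim:: qualification since my code lives outside that namespace; the init_state wrapper is just a name I made up for the example, and the step()/exp_avg()/exp_avg_sq() setters are the ones I see declared on AdamParamState in adam.h.

// Minimal sketch of creating and registering an AdamParamState entry,
// adapted from the first-step branch in adam.cpp. The function name
// init_state is hypothetical and only holds the example.
void init_state(torch::optim::Optimizer *optimizer) {
  for (auto& group : optimizer->param_groups()) {
    for (auto& p : group.params()) {
      auto key = c10::guts::to_string(p.unsafeGetTensorImpl());
      auto& state_ = optimizer->state();
      if (state_.find(key) == state_.end()) {
        // Create a fresh per-parameter state, as adam.cpp does on the first update.
        auto state = std::make_unique<torch::optim::AdamParamState>();
        state->step(0);
        state->exp_avg(torch::zeros_like(p, torch::MemoryFormat::Preserve));
        state->exp_avg_sq(torch::zeros_like(p, torch::MemoryFormat::Preserve));
        state_[key] = std::move(state);
      }
    }
  }
}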