创建 AdamParamState 的实例

问题描述 投票:0回答:2

我需要创建一个

AdamParamState
的实例。我查看了
adam.cpp
代码作为示例,并相应地从那里复制了以下代码。但是,使用提供的标头,它仍然无法识别
AdamParamState

我感谢对此事的任何帮助或评论。

#include <torch/optim/adam.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/nn/module.h>
#include <torch/serialize/archive.h>
#include <torch/utils.h>

#include <ATen/ATen.h>

void get_state(torch::optim::Optimizer *optimizer){
    for (auto& group : optimizer->param_groups()) {
        for (auto &p : group.params()) {
            if (!p.grad().defined()) {
                continue;
            }

            auto grad = p.grad();
            TORCH_CHECK(!grad.is_sparse(),
                        "Adam does not support sparse gradients"/*, please consider SparseAdam instead*/);
            ska::flat_hash_map<std::string, std::unique_ptr<torch::optim::OptimizerParamState>>& state_ = optimizer->state();
            auto param_state = state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
            auto tmp_ = p.dim();
            int tmp_0;
            int tmp_1;
            if (tmp_ > 0)
                tmp_0  = p.size(0);
            if (tmp_ > 1)
                tmp_1  = p.size(1);

            std::cout << tmp_ << tmp_0 << tmp_1 << std::endl;
//                auto& options = static_cast<AdamOptions&>(group.options());
            auto& state = static_cast<AdamParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);
        }
    }

}
pytorch torch libtorch
2个回答
0
投票

我发现这有效:

auto& state = static_cast<torch::optim::AdamParamState&>(*state_[c10::guts::to_string(p.unsafeGetTensorImpl())]);

非常简单又多汁!


0
投票

我从开源中读取了代码,c10::guts::to_string() 函数出现错误,如下图所示。

enter image description here

你能帮我纠正这个问题吗?

谢谢你

© www.soinside.com 2019 - 2024. All rights reserved.