How to compute second-order derivatives with PyTorch on the GPU


I have a Python snippet from a deep reinforcement learning algorithm that performs second-order optimization, computing second-order derivatives via the Hessian and the Fisher information matrix. I normally run the whole thing on the GPU (CUDA), but the second-derivative computation fails under CUDA with:

NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: 
with torch.backends.cudnn.flags(enabled=False):
    output = model(inputs)
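
The error message itself points at a workaround that stays on the GPU: disable cuDNN while the RNN forward pass builds the graph, so autograd records the double-differentiable fallback kernels instead of the cuDNN ones. Note that the flag must be off during the forward pass you later differentiate twice; wrapping only the second gradient call is not enough. A minimal standalone sketch with a toy nn.LSTM (shapes are arbitrary):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True).cuda()
x = torch.randn(4, 10, 8, device="cuda", requires_grad=True)

# The forward pass must run with cuDNN disabled so the recorded graph
# supports double backward.
with torch.backends.cudnn.flags(enabled=False):
    out, _ = lstm(x)

loss = out.pow(2).sum()
(g,) = torch.autograd.grad(loss, x, create_graph=True)   # first backward
(gg,) = torch.autograd.grad(g.pow(2).sum(), x)           # second backward, on CUDA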

Because of this I had to move the code to the CPU, where it now runs sequentially instead of in parallel and takes very long to run:

grads = torch.autograd.grad(policy_loss, self.policy.Actor.parameters(), retain_graph=True)
loss_grad = torch.cat([grad.view(-1) for grad in grads])

# Fisher-vector product: returns (J^T M J) v / N + damping * v, where
# J = d mu / d theta and M = diag(1 / sigma^2); the FIM is never materialized.
def Fvp_fim(v=-loss_grad):
    with torch.backends.cudnn.flags(enabled=False):
        M, mu, info = self.policy.Actor.get_fim(states_batch)
        mu = mu.view(-1)
        filter_input_ids = set([info['std_id']])

        t = torch.ones(mu.size(), requires_grad=True, device=mu.device)
        mu_t = (mu * t).sum()
        Jt = compute_flat_grad(mu_t, self.policy.Actor.parameters(), filter_input_ids=filter_input_ids, create_graph=True)
        Jtv = (Jt * v).sum()
        Jv = torch.autograd.grad(Jtv, t)[0]
        MJv = M * Jv.detach()
        mu_MJv = (MJv * mu).sum()
        JTMJv = compute_flat_grad(mu_MJv, self.policy.Actor.parameters(), filter_input_ids=filter_input_ids, create_graph=True).detach()
        JTMJv /= states_batch.shape[0]
        # The Fisher block for the std parameters (filtered out of the autograd
        # pass above) is added analytically: 2 per dimension for a Gaussian
        # with learned log-std.
        std_index = info['std_index']
        JTMJv[std_index: std_index + M.shape[0]] += 2 * v[std_index: std_index + M.shape[0]]
        return JTMJv + v * self.damping
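
For context, in TRPO-style natural-gradient methods a Fisher-vector product like Fvp_fim is typically fed to a conjugate-gradient solver, so the step direction F^-1 g is obtained without ever forming F. A generic sketch (the solver below is standard textbook CG, not part of the original code):

import torch

def conjugate_gradient(Avp, b, n_iters=10, tol=1e-10):
    # Solves A x = b given only the matrix-vector product Avp(p) = A @ p.
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rs_old = r.dot(r)
    for _ in range(n_iters):
        Ap = Avp(p)
        alpha = rs_old / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r.dot(r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

step_dir = conjugate_gradient(Fvp_fim, -loss_grad)  # x ~ F^-1 g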

The function above is the main one that computes the second-order derivatives. Below are the supporting function and the related class it uses.

# Flattens the gradient of `output` w.r.t. `inputs` into a single 1-D tensor;
# parameters whose index is in `filter_input_ids` are skipped during autograd
# and zero-filled in the result.
def compute_flat_grad(output, inputs, filter_input_ids=set(), retain_graph=True, create_graph=False):
    if create_graph:
        retain_graph = True

    inputs = list(inputs)
    params = []
    for i, param in enumerate(inputs):
        if i not in filter_input_ids:
            params.append(param)

    grads = torch.autograd.grad(output, params, retain_graph=retain_graph, create_graph=create_graph, allow_unused=True)

    j = 0
    out_grads = []
    for i, param in enumerate(inputs):
        if i in filter_input_ids:
            out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
        else:
            if grads[j] is None:  # allow_unused=True can return None for unused params
                out_grads.append(torch.zeros(param.view(-1).shape, device=param.device, dtype=param.dtype))
            else:
                out_grads.append(grads[j].view(-1))
            j += 1
    grads = torch.cat(out_grads)

    for param in params:
        param.grad = None
    return grads
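
A quick standalone usage sketch with a toy model (hypothetical shapes), showing that the helper returns one flat 1-D gradient with zero-filled slots for filtered parameters:

import torch
import torch.nn as nn

net = nn.Linear(3, 2)
out = net(torch.randn(5, 3)).sum()
# Pretend parameter 0 (the weight) must be filtered out, as done for the
# std parameter in Fvp_fim above.
flat = compute_flat_grad(out, net.parameters(), filter_input_ids={0})
print(flat.shape)  # torch.Size([8]): 6 zero entries for the weight + 2 for the bias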

------

import torch
import torch.nn as nn


from agents.models.feature_extracter import LSTMFeatureExtractor
from agents.models.policy import PolicyModule
from agents.models.value import ValueModule


class ActorNetwork(nn.Module):
    def __init__(self, args):
        super(ActorNetwork, self).__init__()
        self.FeatureExtractor = LSTMFeatureExtractor(args)
        self.PolicyModule = PolicyModule(args)

    def forward(self, s):
        lstmOut = self.FeatureExtractor(s)
        mu, sigma, action, log_prob = self.PolicyModule(lstmOut)
        return mu, sigma, action, log_prob
    
    def get_fim(self, x):
        """Return the diagonal of the inverse covariance (the Gaussian FIM
        block for the mean), the mean itself, and bookkeeping indices for
        the std parameter."""
        mu, sigma, _, _ = self.forward(x)

        if sigma.dim() == 1:
            sigma = sigma.unsqueeze(0)

        cov_inv = sigma.pow(-2).repeat(x.size(0), 1)

        # Locate the std parameter among the flattened parameters.
        param_count = 0
        std_index = 0
        param_id = 0
        std_id = param_id
        for name, param in self.named_parameters():
            if name == "sigma.weight":
                std_id = param_id
                std_index = param_count
            param_count += param.view(-1).shape[0]
            param_id += 1

        return cov_inv.detach(), mu, {'std_id': std_id, 'std_index': std_index}

Zooming out, a large number of batches all have to run this function, and since every batch executes it sequentially the total runtime grows dramatically. Is it possible to compute second-order derivatives with PyTorch while running on CUDA/GPU?

python pytorch gpu reinforcement-learning cudnn
1 Answer

You can use one of the pure-Python RNNs from torchrl, for example this one: https://github.com/pytorch/rl/blob/df4fa7808e81f4b95ee2c22a4bb768370a669048/torchrl/modules/tensordict_module/rnn.py#L27

They work with vmap, higher-order differentiation, and torch.compile.
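
A minimal sketch of that swap, assuming your torchrl version exports the pure-Python LSTM from torchrl.modules (check the linked file for the exact import path in your release); it mirrors the nn.LSTM interface, so double backward then works on CUDA without disabling anything:

import torch
from torchrl.modules import LSTM  # pure-Python LSTM; import path may vary by version

# Toy drop-in replacement for an nn.LSTM-based feature extractor.
lstm = LSTM(input_size=8, hidden_size=16, batch_first=True).cuda()
x = torch.randn(4, 10, 8, device="cuda", requires_grad=True)

out, _ = lstm(x)
loss = out.pow(2).sum()
# create_graph=True keeps the graph so it can be differentiated a second time.
(g,) = torch.autograd.grad(loss, x, create_graph=True)
# Second backward (e.g. a Hessian-vector product) succeeds: no cuDNN kernels involved.
(hv,) = torch.autograd.grad((g * torch.randn_like(g)).sum(), x)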
