Implementing Adam in PyTorch

Question

For learning purposes, I am trying to implement Adam myself.

Here is my Adam implementation:

import math
import torch
from torch.optim import Optimizer

class ADAMOptimizer(Optimizer):
    """
    implements ADAM Algorithm, as a preceding step.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ADAMOptimizer, self).__init__(params, defaults)
        
    def step(self):
        """
        Performs a single optimization step.
        """
        loss = None
        for group in self.param_groups:
            #print(group.keys())
            #print (self.param_groups[0]['params'][0].size()), First param (W) size: torch.Size([10, 784])
            #print (self.param_groups[0]['params'][1].size()), Second param(b) size: torch.Size([10])
            for p in group['params']:
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Momentum (Exponential MA of gradients)
                    state['exp_avg'] = torch.zeros_like(p.data)
                    #print(p.data.size())
                    # RMS Prop component. (Exponential MA of squared gradients). Denominator.
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                b1, b2 = group['betas']
                state['step'] += 1
                
                # L2 penalty. Gotta add to Gradient as well.
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Momentum
                exp_avg = torch.mul(exp_avg, b1) + (1 - b1)*grad
                # RMS
                exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1-b2)*(grad*grad)
                
                denom = exp_avg_sq.sqrt() + group['eps']

                bias_correction1 = 1 / (1 - b1 ** state['step'])
                bias_correction2 = 1 / (1 - b2 ** state['step'])
                
                adapted_learning_rate = group['lr'] * bias_correction1 / math.sqrt(bias_correction2)

                p.data = p.data - adapted_learning_rate * exp_avg / denom 
                
                if state['step'] % 10000 == 0:
                    print ("group:", group)
                    print("p: ",p)
                    print("p.data: ", p.data) # W = p.data
                
        return loss

I think I implemented everything correctly, but compared to torch.optim.Adam, the loss curve from my implementation is very spiky.

Loss curve from my ADAM implementation: [plot in the original post]

Loss curve from torch.optim.Adam: [plot in the original post]

If someone can tell me what I am doing wrong, I would really appreciate it.

Full code, including data and plots (very easy to run): https://github.com/byorxyz/AMS_pytorch/blob/master/AdamFails_1dConvex.ipynb
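For context, the comparison above boils down to training the same model twice, once with the custom optimizer and once with torch.optim.Adam, and plotting the recorded losses. A rough, self-contained sketch of that setup (the linear model and random data here are placeholders, not the notebook's actual 1-D convex problem):

import copy
import torch
import torch.nn as nn

# Hypothetical stand-in problem (the notebook uses its own data); the point is
# only to show how the two optimizers' loss curves get compared.
torch.manual_seed(0)
X = torch.randn(512, 784)
y = torch.randint(0, 10, (512,))

base_model = nn.Linear(784, 10)
criterion = nn.CrossEntropyLoss()

def loss_curve(optimizer_cls, steps=1000, **kwargs):
    model = copy.deepcopy(base_model)   # identical initialization for a fair comparison
    optimizer = optimizer_cls(model.parameters(), **kwargs)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses

custom_losses = loss_curve(ADAMOptimizer, lr=1e-3)       # the class defined above
reference_losses = loss_curve(torch.optim.Adam, lr=1e-3)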

Tags: neural-network, mathematical-optimization, pytorch, gradient-descent
1 Answer

It looks like the code above never stores the updated exp_avg and exp_avg_sq back into the optimizer state. Another small detail is that the bias correction on the denominator also ends up applied to epsilon. Finally, Adam uses beta2=0.999 by default.
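Written out, the per-parameter update that the corrected code below performs at step t (with the moments m_t and v_t kept in the optimizer state between calls) is:

\begin{align*}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1-\beta_1^{t}}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^{t}} \\
\theta_t &= \theta_{t-1} - \mathrm{lr} \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t + \epsilon}}
\end{align*}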

The following version fixes the problem with only minimal changes:

import torch

class ADAMOptimizer(torch.optim.Optimizer):
    """
    implements ADAM Algorithm, as a preceding step.
    """
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(ADAMOptimizer, self).__init__(params, defaults)

    def step(self):
        """
        Perform a single optimization step.
        """
        loss = None
        for group in self.param_groups:

            for p in group['params']:
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Momentum (Exponential MA of gradients)
                    state['exp_avg'] = torch.zeros_like(p.data)

                    # RMS Prop component. (Exponential MA of squared gradients). Denominator.
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                b1, b2 = group['betas']
                state['step'] += 1

                # Add weight decay if any
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                # Momentum
                exp_avg = torch.mul(exp_avg, b1) + (1 - b1)*grad
                
                # RMS
                exp_avg_sq = torch.mul(exp_avg_sq, b2) + (1-b2)*(grad*grad)

                mhat = exp_avg / (1 - b1 ** state['step'])
                vhat = exp_avg_sq / (1 - b2 ** state['step'])
                
                denom = torch.sqrt(vhat + group['eps'])

                p.data = p.data - group['lr'] * mhat / denom 
                
                # Save state
                state['exp_avg'], state['exp_avg_sq'] = exp_avg, exp_avg_sq 

        return loss
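As an aside on the "Save state" fix: an equivalent alternative is to update the moment buffers in place, which removes the need to write them back into state at the end of the loop and is closer to how the built-in torch.optim.Adam maintains its buffers. A small self-contained sketch of the two update lines in that style:

import torch

# Stand-in tensors; inside the optimizer these would be grad, state['exp_avg'],
# and state['exp_avg_sq'] for the current parameter.
b1, b2 = 0.9, 0.999
grad = torch.randn(10, 784)
exp_avg = torch.zeros_like(grad)
exp_avg_sq = torch.zeros_like(grad)

# In-place versions of the "Momentum" and "RMS" updates: the state tensors are
# mutated directly, so no explicit "Save state" step is needed afterwards.
exp_avg.mul_(b1).add_(grad, alpha=1 - b1)                # m <- b1*m + (1-b1)*g
exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)   # v <- b2*v + (1-b2)*g*g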