我有一个机器学习模型,其中模型参数的梯度是解析的,不需要自动区分。但是,我仍然希望能够在Flux中利用不同的优化器,而不必依赖Zygote进行区分。这是我的代码片段。
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = Flux.Params([b, c, U, W])
opt = ADAM(0.01)
然后我有了一个计算模型参数θ
的解析梯度的函数。
function gradients(x) # x = one input data point or a batch of input data points
# stuff to calculate gradients of each parameter
# returns gradients of each parameter
然后我希望能够执行以下操作。
grads = gradients(x)
update!(opt, θ, grads)
我的问题是:gradient(x)
函数需要返回哪种形式/类型才能执行update!(opt, θ, grads)
,我该怎么做?
如果不使用Params
,则只需将grads
作为渐变即可。唯一的要求是θ
和grads
的大小相同。
例如,map((x, g) -> update!(opt, x, g), θ, grads)
其中θ == [b, c, U, W]
和grads = [gradients(b), gradients(c), gradients(U), gradients(W)]
(不确定gradients
期望为您提供的输入)。
更新:但是要回答您的原始问题,gradients
需要返回在此处找到的Grads
对象:https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88
所以类似
# within gradient function body assuming gb is the gradient w.r.t b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b
但不使用Params
可能更易于调试。唯一的问题是,没有update!
可以在一系列参数上工作,但是您可以轻松定义自己的参数:
function Flux.Optimise.update!(opt, xs::Tuple, gs)
for (x, g) in zip(xs, gs)
update!(opt, x, g)
end
end
# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = (b, c, U, W)
opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # return tuple (gb, gc, gU, gW)
update!(opt, θ, grads)
更新2:
[另一种选择是仍然使用Zygote进行渐变,以便它自动为您设置Grads
对象,但是使用自定义伴随物,以便它使用您的分析函数来计算伴随物。假设您的ML模型定义为称为f
的函数,因此f(x)
返回输入x
的模型输出。我们还假设gradients(x)
返回解析梯度w.r.t。 x
就像您在问题中提到的那样。然后,下面的代码仍将使用Zygote的AD,它将正确填充Grads
对象,但是它将使用您为函数f
计算梯度的定义:
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)
θ = Flux.Params([b, c, U, W])
f(x) = # define your model
gradients(x) = # define your analytical gradient
# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)
opt = ADAM(0.01)
x = # generate input to model
y = # output of model
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)
注意,以上我以Flux.mse
作为损失示例。这种方法的缺点是Zygote的gradient
函数需要标量输出。如果您的模型因某种损失而将输出标量误差值,那么@adjoint
是最好的方法。这将适合您正在执行标准ML的情况,唯一的变化是希望Zygote使用您的函数来解析计算f
的梯度。
如果您正在做更复杂的事情且无法使用Zygote.gradient
,则第一种方法(不使用Params
)是最合适的。 Params
实际上仅是为了与Flux的旧AD向后兼容,因此,如果可能的话,最好避免使用它。