在Python中,将二次多项式拟合到p维数据并计算其梯度和Hessian矩阵的最佳方法是什么?

问题描述 投票:0回答:1

我一直在尝试使用scikit-learn库来解决这个问题。 大致:

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

# Make or load an n x p data matrix X and n x 1 array y of the corresponding
# function values.

poly = PolynomialFeatures(degree=2)
Xp = poly.fit_transform(X)
model = LinearRegression()
model.fit(Xp, y)

# Approximate the derivatives of the gradient and Hessian using the relevant
# finite-difference equations and model.predict.

如上所示,

sklearn
做出的设计选择是将多项式回归分离为
PolynomialFeatures
LinearRegression
,而不是将它们组合成单个函数。 这种分离在概念上有优势,但也有一个主要缺点:它有效地阻止了
model
提供方法
gradient
hessian
,如果这样做的话
model
将会更加有用。

我当前的解决方法使用有限差分方程和

model.predict
来近似梯度和 Hessian 的元素(如此处所述)。 但我不喜欢这种方法 - 它对浮点误差很敏感,并且构建梯度所需的“精确”信息和 Hessian 已经包含在
model.coef_
中。

有没有更优雅或更准确的方法来拟合 p 维多项式并在 Python 中找到其梯度和 Hessian 矩阵? 我可以使用使用不同库的方法。

python scikit-learn linear-regression polynomials hessian-matrix
1个回答
0
投票

要计算多项式的梯度或 Hessian,需要知道每个单项式中变量的指数以及相应的单项式系数。第一条信息由

poly.powers_
提供,第二条由
model.coef_
提供:

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
import numpy as np

np.set_printoptions(precision=2, suppress=True)

X = np.arange(6).reshape(3, 2)
y = np.arange(3)
poly = PolynomialFeatures(degree=2)
Xp = poly.fit_transform(X)
model = LinearRegression()
model.fit(Xp, y)

print("Exponents:")
print(poly.powers_.T)
print("Coefficients:")
print(model.coef_)

这给出:

Exponents:
[[0 1 0 2 1 0]
 [0 0 1 0 1 2]]
Coefficients:
[ 0.    0.13  0.13 -0.12 -0.    0.13]

以下函数可用于计算数组给定点处的梯度

x

def gradient(x, powers, coeffs):
    x = np.array(x)
    gp = np.maximum(0, powers[:, np.newaxis] - np.eye(powers.shape[1], dtype=int))
    gp = gp.transpose(1, 2, 0)
    gc = coeffs * powers.T
    return (((x[:, np.newaxis] ** gp).prod(axis=1)) * gc).sum(axis=1)

例如,我们可以用它来计算点

[0, 1]
处的梯度:

print(gradient([0, 1],  poly.powers_, model.coef_))

这给出:

[0.13 0.38]

给定点的 Hessian 矩阵可以用类似的方式计算。

© www.soinside.com 2019 - 2024. All rights reserved.