寻找3D线性回归线的斜率

问题描述 投票:0回答:1

我正在尝试在3D空间中找到直线的斜率。在this post

中给出绘制这种线的解决方案

这是上面链接中的给定代码:

import numpy as np

pts = np.add.accumulate(np.random.random((10,3)))
x,y,z = pts.T

# this will find the slope and x-intercept of a plane
# parallel to the y-axis that best fits the data
A_xz = np.vstack((x, np.ones(len(x)))).T
m_xz, c_xz = np.linalg.lstsq(A_xz, z)[0]

# again for a plane parallel to the x-axis
A_yz = np.vstack((y, np.ones(len(y)))).T
m_yz, c_yz = np.linalg.lstsq(A_yz, z)[0]

# the intersection of those two planes and
# the function for the line would be:
# z = m_yz * y + c_yz
# z = m_xz * x + c_xz
# or:
def lin(z):
    x = (z - c_xz)/m_xz
    y = (z - c_yz)/m_yz
    return x,y

#verifying:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

fig = plt.figure()
ax = Axes3D(fig)
zz = np.linspace(0,5)
xx,yy = lin(zz)
ax.scatter(x, y, z)
ax.plot(xx,yy,zz)
plt.savefig('test.png')
plt.show()

在数学上,我知道如何找到两个平面的交点以及给定线的斜率,但是我很难将其放入代码中。

如何使用该解决方案找到所得回归线的斜率?

python numpy matplotlib regression linear-regression
1个回答
0
投票

3D线的“斜率​​”通常被视为“投影”到x,y和z平面上的线的斜率。参见this question的第二个答案>

如果这是您想要的,那么计算这些就很容易了;您下面的代码的修改后的版本将其放入sxsysz变量中:

import numpy as np
import matplotlib.pyplot as plt
from   mpl_toolkits.mplot3d import Axes3D
from   math import pow, sqrt

pts = np.add.accumulate(np.random.random((10,3)))

x, y, z = pts.T

# plane parallel to the y-axis
A_xz = np.vstack((x, np.ones(len(x)))).T
m_xz, c_xz = np.linalg.lstsq(A_xz, z, rcond=None)[0]

# plane parallel to the x-axis
A_yz = np.vstack((y, np.ones(len(y)))).T
m_yz, c_yz = np.linalg.lstsq(A_yz, z, rcond=None)[0]

# the intersection of those two planes and
# the function for the line would be:
# z = m_yz * y + c_yz
# z = m_xz * x + c_xz
# or:
def lin(z):
    x = (z - c_xz)/m_xz
    y = (z - c_yz)/m_yz
    return x,y


# get 2 points on the intersection line 
za = z[0]
zb = z[len(z) - 1]
xa, ya = lin(za)
xb, yb = lin(zb)

# get distance between points
len = sqrt(pow(xb - xa, 2) + pow(yb - ya, 2) + pow(zb - za, 2))

# get slopes (projections onto x, y and z planes)
sx = (xb - xa) / len  # x slope
sy = (yb - ya) / len  # y slope
sz = (zb - za) / len  # z slope

# integrity check - the sum of squares of slopes should equal 1.0
# print (pow(sx, 2) + pow(sy, 2) + pow(sz, 2))

fig = plt.figure()
ax = Axes3D(fig)
ax.set_xlabel("x, slope: %.4f" %sx, color='blue')
ax.set_ylabel("y, slope: %.4f" %sy, color='blue')
ax.set_zlabel("z, slope: %.4f" %sz, color='blue')
ax.scatter(x, y, z)
ax.plot([xa], [ya], [za], markerfacecolor='k', markeredgecolor='k', marker = 'o')
ax.plot([xb], [yb], [zb], markerfacecolor='k', markeredgecolor='k', marker = 'o')
ax.plot([xa, xb], [ya, yb], [za, zb], color = 'r')

plt.show()

下面的输出图显示了所讨论的线,它刚好在2个极限xyz点之间绘制。

enter image description here希望对您有所帮助

© www.soinside.com 2019 - 2024. All rights reserved.