我看不到线穿过我在图上的点 使用带有 python 3.7 和 numpy 的 conda 环境,
但是误差很大
数据集:
错误:
图片
#representar puntos
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
#dataset.descripve()
data=pd.read_csv('Salary_Data.csv')
Xs=data[['YearsExperience']]
X=np.array(Xs)
y_ds=data['Salary'].values
data.describe()
def h(x,w):
return w[0]+w[1]*x
w=np.random.rand(2)#esta podria ser la causa de que no converga
#w[0]=3
#w[1]=5
y_prediccion=[h(i,w) for i in X]
#otra forma
#error
def error(x,y,w):
suma=0
for i in range(len(y)):
suma=suma+(y[i]-h(x[i],w))**2
return suma/(2*len(y))
print("error mse"+str( error(y_ds,X,w)))
plt.plot(X,y_ds,'*')#nube de puntos
plt.plot(X,y_prediccion)#prediccion
plt.show()
#derivadas
def grad(x,y,w):
sum1=0
sum2=0
for i in range (len(y)):
sum1 = sum1 + (y[i]-h(x[i],w))*(-1)
sum2 = sum2 + (y[i]-h(x[i],w))*(-x[i])
gra_w0=(sum1)
gra_w1=(sum2)
return gra_w0/len(y),gra_w1/len(y)
#creamos un algoritmo de aprendisaje
def train(X,y_ds,w,epochs,alpha):
list_error=[]
time=[]
for i in range(epochs):
err=error(y_ds,X,w)# (calculo de error
list_error.append(err)
time.append(i)
print(err)#imprimr para ver si esta disminuyendo en el tiempo
grad_w0,grad_w1=grad(X,y_ds,w)
#grad_w1=grad(y_ds,x_ds,w) #linea ala que se debia de modificar para solucionar el error(can't multiply sequence by non-int of type 'float') ,no es lomismo que ponerlo todo en una sola linea mirar arriba
w[0]=w[0]-alpha*grad_w0
w[1]=w[1]-alpha*grad_w1
#calculo grafica error error
#plt.xlabel("time")
#plt.ylabel("error")
#plt.plot(time,list_error,"*")
#calculo de la recta real
y_r=[h(i,w)for i in X]
plt.plot(X,y_ds,'*')
plt.plot(X,y_r)
plt.show()
train(X,y_ds,w,50,0.001)#auno pongo ma epocas apartir de 5o me sale over flow ,probe cambiando posicion x ,y_s mejoro->(x ,y)en gradiete
print('-------------------------')