我打算在Spark中使用线性回归。首先,我从官方文档中查看了示例(您可以在此处找到)
我也发现了这个问题,该问题与我的问题基本上相同。答案建议调整步长,这也是我尝试做的,但是结果仍然像不调整步长一样随机。我正在使用的代码如下所示:
from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel # Load and parse the data def parsePoint(line): values = [float(x) for x in line.replace(',', ' ').split(' ')] return LabeledPoint(values[0], values[1:]) data = sc.textFile("data/mllib/ridge-data/lpsa.data") parsedData = data.map(parsePoint) # Build the model model = LinearRegressionWithSGD.train(parsedData,100000,0.01) # Evaluate the model on training data valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE))
结果如下:
(Expected Label, Predicted Label) (-0.4307829, -0.7824231588143065) (-0.1625189, -0.6234287565006766) (-0.1625189, -0.41979307020176226) (-0.1625189, -0.6517649080382241) (0.3715636, -0.38543073492870156) (0.7654678, -0.7329426818746223) (0.8544153, -0.33273378445315) (1.2669476, -0.36663240056848917) (1.2669476, -0.47541427992967517) (1.2669476, -0.1887811811672498) (1.3480731, -0.28646712964591936) (1.446919, -0.3425075015127807) (1.4701758, -0.14055275401870437) (1.4929041, -0.06819303631450688) (1.5581446, -0.772558163357755) (1.5993876, -0.19251656391040356) (1.6389967, -0.38105697301968594) (1.6956156, -0.5409707504639943) (1.7137979, 0.14914490255841997) (1.8000583, -0.0008818203337740971) (1.8484548, 0.06478505759587616) (1.8946169, -0.0685096804502884) (1.9242487, -0.14607596025743624) (2.008214, -0.24904211817187422) (2.0476928, -0.4686214015035236) (2.1575593, 0.14845590638215034) (2.1916535, -0.5140996125798528) (2.2137539, 0.6278134417345228) (2.2772673, -0.35049969044209983) (2.2975726, -0.06036824276546304) (2.3272777, -0.18585219083806218) (2.5217206, -0.03167349168036536) (2.5533438, -0.1611040092884861) (2.5687881, 1.1032200139582564) (2.6567569, 0.04975777739217784) (2.677591, -0.01426285133724671) (2.7180005, 0.07853368755223371) (2.7942279, -0.4071930969456503) (2.8063861, 0.000492545291049501) (2.8124102, -0.019947344959659177) (2.8419982, 0.03023139779978133) (2.8535925, 0.5421291261646886) (2.9204698, 0.3923068894170366) (2.9626924, 0.21639267973240908) (2.9626924, -0.22540434628281075) (2.9729753, 0.2363938458250126) (3.0130809, 0.35136961387278565) (3.0373539, 0.013876918415846595) (3.2752562, 0.49970959078043126) (3.3375474, 0.5436323480304672) (3.3928291, 0.48746004196839055) (3.4355988, 0.3350764608584778) (3.4578927, 0.6127634045652381) (3.5160131, -0.03781697409079157) (3.5307626, 0.2129806543371961) (3.5652984, 0.5528805608876549) (3.5876769, 0.06299042506665305) (3.6309855, 0.5648082098866389) (3.6800909, -0.1588172848952902) (3.7123518, 0.1635062564072022) (3.9843437, 0.7827244309795267) (3.993603, 0.6049246406551748) (4.029806, 0.06372113813964088) (4.1295508, 0.24281029469705093) (4.3851468, 0.5906868686740623) (4.6844434, 0.4055055537895428) (5.477509, 0.7335244827296759) Mean Squared Error = 6.83550847274
那么,我想念什么?由于数据来自官方的spark文档,我猜想它应该适合对其应用线性回归(并至少得到一个合理的好预测)?
对于初学者,您缺少拦截。尽管自变量的平均值接近零:
parsedData.map(lambda lp: lp.features).mean() ## DenseVector([-0.031, -0.0066, 0.1182, -0.0199, 0.0178, -0.0249, ## -0.0294, 0.0669]
因变量的平均值离它很远:
parsedData.map(lambda lp: lp.label).mean() ## 2.452345085074627
在这种情况下,强制回归线穿过原点是没有意义的。因此,让我们看看如何LinearRegressionWithSGD使用默认参数和添加的拦截执行:
LinearRegressionWithSGD
model = LinearRegressionWithSGD.train(parsedData, intercept=True) valuesAndPreds = (parsedData.map(lambda p: (p.label, model.predict(p.features)))) valuesAndPreds.map(lambda vp: (vp[0] - vp[1]) ** 2).mean() ## 0.44005904185432504
让我们将其与分析解决方案进行比较
import numpy as np from sklearn import linear_model features = np.array(parsedData.map(lambda lp: lp.features.toArray()).collect()) labels = np.array(parsedData.map(lambda lp: lp.label).collect()) lm = linear_model.LinearRegression() lm.fit(features, labels) np.mean((lm.predict(features) - labels) ** 2) ## 0.43919976805833411
尽您所能,使用获得的结果LinearRegressionWithSGD几乎是最佳的。
您可以添加网格搜索,但是在这种特殊情况下,可能没有任何收获。