一尘不染

拆分用于分类问题的数据集的正确程序是什么?

python

我是机器学习和深度学习的新手。我想澄清我与train_test_split训练之前有关的疑问

我有一个size的数据集(302, 100, 5),其中,

(207,100,5) 属于 class 0

(95,100,5) 属于 class 1.

我想使用LSTM执行分类(因为序列数据)

由于各类没有相等的分布集,我如何拆分我的数据集进行训练?

选项1 :考虑整个数据[(302,100, 5) - both classes (0 & 1)],对其进行洗牌,train_test_split,然后进行培训。

选项2: 均等地分割两个班级数据集 [(95,100,5) - class 0 & (95,100,5) - class 1],对其进行洗牌,train_test_split,然后进行训练。

在训练之前进行分割的更好方法是什么,以便我在减少损失,准确性,预测等方面可以获得更好的结果?

如果除了上述2个选项以外,还有其他选择,请推荐,

根据评论部分,我包含了部分数据:

X_train:形状(241 * 100 * 5)

每100 * 5中的每一行对应1个时间步长最后100行对应于100个时间步长(以毫秒为单位)

array([[[0.98620635, 0.        , 0.12752912, 0.60897341, 0.46903766],

        [0.97345112, 0.        , 0.12752912, 0.49205995, 0.38709902],

        [0.9566397 , 0.        , 0.12752912, 0.45728718, 0.42154812],

        ...,

        [0.28669754, 0.8852459 , 0.12752912, 0.8786213 , 0.80125523],

        [0.31559784, 0.8852459 , 0.20968731, 0.89087803, 0.79476987],

        [0.34368841, 0.8852459 , 0.12752912, 0.89087803, 0.71066946]],



       [[0.97957188, 0.14909194, 0.04159147, 0.50548561, 0.34209531],

        [0.9687237 , 0.13964397, 0.04159147, 0.55926067, 0.64613533],

        [0.96596236, 0.13553813, 0.04159147, 0.55903796, 0.85299319],

        ...,

        [0.49309139, 0.72396527, 0.04159147, 0.81998825, 0.12362443],

        [0.52072591, 0.70872926, 0.04159147, 0.82361951, 0.89639432],

        [0.54441507, 0.71835207, 0.04159147, 0.84964602, 1.        ]],



       [[0.48151381, 0.875     , 0.16666667, 0.90637286, 0.62737926],

        [0.53325374, 0.8625    , 0.33333333, 0.87881677, 0.5321154 ],

        [0.57506452, 0.81859091, 0.16666667, 0.84915758, 0.3552661 ],

        ...,

        [0.34456041, 0.92993213, 0.33333333, 0.92953899, 0.78782408],

        [0.39496018, 0.90523485, 0.33333333, 0.9117954 , 0.54579383],

        [0.44187985, 0.8625    , 0.33333333, 0.84163194, 0.25789356]],



       ...,



       [[0.16368355, 0.        , 0.15313225, 0.40101906, 0.36784741],

        [0.15679684, 0.        , 0.15313225, 0.4435126 , 0.67351994],

        [0.15544309, 0.06132052, 0.15313225, 0.40101906, 0.36611345],

        ...,

        [0.43936628, 0.68292683, 0.15313225, 0.82305329, 0.36784741],

        [0.49751546, 0.68292683, 0.07764888, 0.84141109, 0.42828833],

        [0.53288488, 0.68292683, 0.15313225, 0.85959823, 0.36784741]],



       [[0.9418247 , 0.30821318, 0.03072816, 0.744977  , 0.93769733],

        [0.9537216 , 0.28989357, 0.03072816, 0.74576381, 0.98468743],

        [0.96455286, 0.21736423, 0.03072816, 0.74182977, 1.        ],

        ...,

        [0.36273884, 0.60113245, 0.06145633, 0.85409181, 0.32277415],

        [0.38774614, 0.57789971, 0.05844559, 0.82937631, 0.        ],

        [0.41546859, 0.57789971, 0.03072816, 0.79315883, 0.31256578]],



       [[0.97868688, 0.06451613, 0.00411829, 0.64705259, 0.69827586],

        [0.97999663, 0.06451613, 0.02256676, 0.66812232, 0.75195925],

        [0.97143037, 0.02476377, 0.02256676, 0.66317859, 0.78487461],

        ...,

        [0.50336862, 0.73867709, 0.02256676, 0.84921606, 0.1226489 ],

        [0.54003486, 0.72043011, 0.02256676, 0.82679269, 0.20297806],

        [0.57594039, 0.70967742, 0.02256676, 0.83350205, 0.        ]]])

Y_train:形状(241,)

[1. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0.

 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 1.

 0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0.

 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.

 1. 0. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1.

 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.

 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0.

 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 0. 1.

 0. 1. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.

 1. 0. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1.

 1.]

供参考
,如您在上面看到的,X火车数据很大,我不能包含我的整个X_train数据的完整集合。因此,这里我仅提供数据的一个细分,以更好地了解1个细分的数据(i.e X_train[0] : shape- (100*5))。其余的240或多或少如下图所示

array([[9.86206354e-01, 0.00000000e+00, 1.27529123e-01, 2.29139335e-02,

        6.08973407e-01, 4.69037657e-01],

       [9.73451120e-01, 0.00000000e+00, 1.27529123e-01, 2.60807671e-02,

        4.92059955e-01, 3.87099024e-01],

       [9.56639704e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,

        4.57287179e-01, 4.21548117e-01],

       [9.34897700e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,

        4.84177685e-01, 4.69037657e-01],

       [9.18030989e-01, 0.00000000e+00, 1.27529123e-01, 2.64184174e-02,

        4.86406180e-01, 4.08577406e-01],

       [9.02168015e-01, 0.00000000e+00, 1.27529123e-01, 2.64020795e-02,

        4.84920517e-01, 4.04184100e-01],

       [8.82551572e-01, 0.00000000e+00, 1.27529123e-01, 2.56783096e-02,

        4.51195959e-01, 3.78661088e-01],

       [8.69975342e-01, 0.00000000e+00, 1.27529123e-01, 2.40477851e-02,

        4.70286733e-01, 4.23640167e-01],

       [8.41027241e-01, 0.00000000e+00, 1.27529123e-01, 1.75387576e-02,

        5.04754123e-01, 4.34728033e-01],

       [8.28189535e-01, 5.28763040e-01, 1.27529123e-01, 6.89133486e-03,

        4.98662903e-01, 4.58368201e-01],

       [8.21784739e-01, 8.21162444e-01, 1.27529123e-01, 1.06196483e-02,

        5.87431288e-01, 5.72594142e-01],

       [8.26651597e-01, 9.96721311e-01, 1.27529123e-01, 1.75044480e-02,

        6.89050661e-01, 5.40376569e-01],

       [8.42115326e-01, 1.00000000e+00, 1.27529123e-01, 1.71205069e-02,

        8.35388501e-01, 4.69037657e-01],

       [8.64071009e-01, 9.26875310e-01, 1.27529123e-01, 1.34068975e-02,

        1.00000000e+00, 4.65062762e-01],

       [8.79579724e-01, 7.60158967e-01, 1.27529123e-01, 4.65303975e-03,

        9.61744169e-01, 3.65481172e-01],

       [9.03630040e-01, 7.61549925e-01, 1.27529123e-01, 4.21518348e-03,

        9.22076957e-01, 3.78033473e-01],

       [9.18435858e-01, 6.72429210e-01, 1.27529123e-01, 2.70229205e-03,

        9.39979201e-01, 5.03138075e-01],

       [9.29983046e-01, 6.85345256e-01, 1.27529123e-01, 9.05120794e-04,

        8.53736443e-01, 5.52510460e-01],

       [9.48081232e-01, 5.78539493e-01, 1.27529123e-01, 6.96485550e-03,

        8.84415391e-01, 3.04602510e-01],

       [9.48112160e-01, 5.55091903e-01, 1.27529123e-01, 1.10493356e-02,

        8.19046204e-01, 4.78661088e-01],

       [9.61281634e-01, 5.08693492e-01, 1.27529123e-01, 9.36162843e-03,

        8.23651761e-01, 3.21548117e-01],

       [9.72179346e-01, 4.91803279e-01, 1.27529123e-01, 9.82725917e-03,

        7.57391175e-01, 4.96025105e-01],

       [9.84752763e-01, 4.91803279e-01, 1.27529123e-01, 7.04491131e-03,

        7.59322538e-01, 3.95397490e-01],

       [9.90300024e-01, 4.91803279e-01, 1.27529123e-01, 8.19346712e-03,

        7.64819492e-01, 4.69037657e-01],

       [9.88306609e-01, 3.77049180e-01, 1.27529123e-01, 8.62642201e-03,

        7.93492795e-01, 4.16945607e-01],

       [9.91084457e-01, 3.93442623e-01, 1.27529123e-01, 9.16557339e-03,

        7.10741346e-01, 4.72175732e-01],

       [1.00000000e+00, 3.78936910e-01, 1.27529123e-01, 1.16538387e-02,

        6.93359085e-01, 4.76987448e-01],

       [9.98925974e-01, 3.93442623e-01, 1.27529123e-01, 1.21309060e-02,

        7.16609716e-01, 3.46025105e-01],

       [9.92838888e-01, 3.32141083e-01, 1.27529123e-01, 1.19315833e-02,

        7.31540633e-01, 4.16527197e-01],

       [9.90637415e-01, 3.36910084e-01, 1.27529123e-01, 9.95632874e-03,

        7.12524142e-01, 4.15481172e-01],

       [9.90761125e-01, 3.38301043e-01, 1.27529123e-01, 6.59235091e-03,

        6.86970732e-01, 4.37656904e-01],

       [9.90274720e-01, 3.27868852e-01, 2.10913550e-01, 5.68396253e-03,

        7.09181399e-01, 4.99372385e-01],

       [9.83015202e-01, 3.27868852e-01, 1.27529123e-01, 2.14974358e-02,

        7.31392067e-01, 6.41631799e-01],

       [9.77392028e-01, 2.85245902e-01, 1.47762109e-01, 2.52861995e-02,

        7.09478532e-01, 6.07112971e-01],

       [9.75300207e-01, 2.78688525e-01, 1.27529123e-01, 2.91468501e-02,

        6.70257020e-01, 6.28242678e-01],

       [9.74917831e-01, 2.71733731e-01, 1.27529123e-01, 3.58780734e-02,

        6.70257020e-01, 5.72594142e-01],

       [9.64950755e-01, 2.62295082e-01, 1.27529123e-01, 3.92992339e-02,

        6.36383895e-01, 6.67991632e-01],

       [9.63159774e-01, 2.62295082e-01, 1.27529123e-01, 4.82932591e-02,

        6.93581934e-01, 5.46443515e-01],

       [9.54983679e-01, 2.90511674e-01, 1.27529123e-01, 4.90627752e-02,

        6.59708810e-01, 7.40376569e-01],

       [9.57595643e-01, 3.11475410e-01, 1.27529123e-01, 4.72492660e-02,

        6.49977715e-01, 5.61297071e-01],

       [9.51511369e-01, 2.95081967e-01, 1.27529123e-01, 1.82576261e-02,

        6.64314366e-01, 5.22384937e-01],

       [9.48528275e-01, 2.95081967e-01, 1.27529123e-01, 3.89659403e-03,

        6.29846977e-01, 3.20711297e-01],

       [9.47085931e-01, 2.95081967e-01, 1.27529123e-01, 6.86682798e-03,

        6.48417769e-01, 4.38284519e-01],

       [9.38153518e-01, 2.95081967e-01, 1.27529123e-01, 5.73951146e-03,

        7.04130144e-01, 5.32635983e-01],

       [9.38114156e-01, 2.95081967e-01, 1.27529123e-01, 2.05955826e-02,

        6.85782202e-01, 5.47280335e-01],

       [9.35597786e-01, 2.95081967e-01, 1.27529123e-01, 2.91141743e-02,

        6.69142772e-01, 7.13807531e-01],

       [9.29311077e-01, 2.72826627e-01, 1.27529123e-01, 2.91141743e-02,

        6.81622344e-01, 5.72594142e-01],

       [9.25495753e-01, 2.23646299e-01, 1.27529123e-01, 2.65507546e-02,

        6.35566781e-01, 6.41004184e-01],

       [9.18525829e-01, 2.08643815e-03, 1.27529123e-01, 2.37618715e-02,

        6.09641955e-01, 5.02928870e-01],

       [8.91801693e-01, 0.00000000e+00, 1.27529123e-01, 9.27013608e-03,

        5.26073392e-01, 4.21338912e-01],

       [8.77693149e-01, 0.00000000e+00, 1.27529123e-01, 8.13628440e-03,

        4.22522656e-01, 3.44560669e-01],

       [8.61894841e-01, 0.00000000e+00, 1.27529123e-01, 1.49639014e-02,

        4.52755906e-01, 3.65481172e-01],

       [8.44254943e-01, 0.00000000e+00, 1.27529123e-01, 2.29515107e-02,

        4.59069975e-01, 3.76150628e-01],

       [8.21183060e-01, 0.00000000e+00, 1.27529123e-01, 3.97583295e-02,

        4.60852771e-01, 2.60460251e-01],

       [8.04116726e-01, 0.00000000e+00, 1.27529123e-01, 5.89292454e-02,

        4.26905363e-01, 1.97907950e-01],

       [7.81311943e-01, 0.00000000e+00, 1.27529123e-01, 8.53656345e-02,

        4.37379290e-01, 1.00836820e-01],

       [7.60863270e-01, 0.00000000e+00, 1.27529123e-01, 1.03087377e-01,

        4.37379290e-01, 6.98744770e-02],

       [7.41227145e-01, 0.00000000e+00, 1.27529123e-01, 1.14206966e-01,

        4.27128213e-01, 1.58368201e-01],

       [7.26694052e-01, 0.00000000e+00, 1.27529123e-01, 1.17776801e-01,

        4.37379290e-01, 0.00000000e+00],

       [7.08716764e-01, 0.00000000e+00, 1.27529123e-01, 1.17288297e-01,

        4.48596048e-01, 2.18619247e-01],

       [6.90483621e-01, 0.00000000e+00, 1.27529123e-01, 1.08491961e-01,

        4.58549993e-01, 1.26987448e-01],

       [6.67451099e-01, 0.00000000e+00, 1.27529123e-01, 8.38217010e-02,

        4.99628584e-01, 3.55020921e-01],

       [6.51610618e-01, 0.00000000e+00, 1.27529123e-01, 4.32889541e-02,

        5.10919626e-01, 4.83054393e-01],

       [6.31195684e-01, 0.00000000e+00, 1.27529123e-01, 1.29200275e-02,

        5.21170703e-01, 4.97907950e-01],

       [6.14317726e-01, 0.00000000e+00, 2.26241570e-01, 9.32895259e-04,

        4.98960036e-01, 4.69037657e-01],

       [5.98165158e-01, 0.00000000e+00, 5.90435316e-01, 0.00000000e+00,

        4.61892735e-01, 5.03556485e-01],

       [5.68221755e-01, 0.00000000e+00, 6.33353771e-01, 1.61745413e-03,

        4.25122567e-01, 4.69037657e-01],

       [5.35292447e-01, 0.00000000e+00, 1.00000000e+00, 8.99402522e-03,

        3.58490566e-01, 5.10041841e-01],

       [5.10766973e-01, 0.00000000e+00, 3.93010423e-01, 3.39894098e-02,

        3.27068786e-01, 6.15690377e-01],

       [4.78939807e-01, 0.00000000e+00, 5.32188841e-01, 5.98114931e-02,

        3.27068786e-01, 6.22175732e-01],

       [4.47053597e-01, 0.00000000e+00, 4.31023912e-01, 8.44245703e-02,

        3.24023176e-01, 6.76150628e-01],

       [4.13654754e-01, 0.00000000e+00, 5.32188841e-01, 1.07209434e-01,

        2.90298618e-01, 7.08577406e-01],

       [3.80151882e-01, 0.00000000e+00, 7.97057020e-01, 1.21122807e-01,

        1.19150201e-01, 4.95397490e-01],

       [3.28235926e-01, 0.00000000e+00, 3.56223176e-01, 1.23820198e-01,

        0.00000000e+00, 6.65271967e-01],

       [2.83452966e-01, 0.00000000e+00, 2.28694053e-01, 1.22658572e-01,

        2.65933739e-02, 5.55648536e-01],

       [2.38616587e-01, 0.00000000e+00, 2.28694053e-01, 1.22990232e-01,

        9.41910563e-02, 4.92887029e-01],

       [1.82964031e-01, 0.00000000e+00, 5.19926426e-01, 1.30564491e-01,

        8.97340663e-02, 4.94142259e-01],

       [1.43835174e-01, 0.00000000e+00, 5.25444513e-01, 1.64135650e-01,

        1.14618927e-01, 7.40585774e-01],

       [1.04402664e-01, 0.00000000e+00, 1.55119559e-01, 2.41378071e-01,

        1.98261774e-01, 6.50418410e-01],

       [7.96438281e-02, 0.00000000e+00, 7.11220110e-02, 3.27145618e-01,

        2.89110088e-01, 7.45188285e-01],

       [6.36065353e-02, 0.00000000e+00, 0.00000000e+00, 4.11129065e-01,

        4.05140395e-01, 6.88912134e-01],

       [4.11672585e-02, 0.00000000e+00, 2.52605763e-01, 5.62182942e-01,

        4.54315852e-01, 1.00000000e+00],

       [2.87063044e-02, 0.00000000e+00, 1.27529123e-01, 6.81786323e-01,

        4.59515674e-01, 9.32217573e-01],

       [1.70269716e-02, 1.58966716e-03, 1.27529123e-01, 7.33474602e-01,

        4.37453573e-01, 6.07322176e-01],

       [3.30361486e-03, 6.37853949e-01, 1.27529123e-01, 8.06276376e-01,

        4.69692468e-01, 7.54602510e-01],

       [0.00000000e+00, 7.89369101e-01, 1.27529123e-01, 8.85843682e-01,

        5.10919626e-01, 8.70502092e-01],

       [5.13114648e-03, 8.19672131e-01, 1.27529123e-01, 9.60932765e-01,

        5.99316595e-01, 8.79288703e-01],

       [2.16829598e-02, 8.36065574e-01, 1.27529123e-01, 9.99121020e-01,

        7.28866439e-01, 8.56903766e-01],

       [4.27951674e-02, 8.36065574e-01, 1.27529123e-01, 1.00000000e+00,

        8.67181697e-01, 7.88912134e-01],

       [7.02334461e-02, 8.36065574e-01, 1.27529123e-01, 9.93500775e-01,

        8.46308127e-01, 9.78451883e-01],

       [9.73680733e-02, 8.36065574e-01, 1.27529123e-01, 9.87896869e-01,

        8.66364582e-01, 8.59414226e-01],

       [1.23611427e-01, 8.36065574e-01, 1.27529123e-01, 9.69613102e-01,

        8.35685634e-01, 9.17991632e-01],

       [1.52157471e-01, 8.68852459e-01, 1.27529123e-01, 9.22226597e-01,

        7.96686971e-01, 9.65062762e-01],

       [1.77979087e-01, 8.68852459e-01, 1.27529123e-01, 8.61132577e-01,

        8.29594414e-01, 8.14225941e-01],

       [2.03010647e-01, 8.84252360e-01, 1.27529123e-01, 8.13277174e-01,

        8.29594414e-01, 9.11506276e-01],

       [2.32490138e-01, 8.85245902e-01, 1.27529123e-01, 7.59549923e-01,

        8.41851137e-01, 9.52301255e-01],

       [2.58952796e-01, 8.85245902e-01, 1.27529123e-01, 6.97804020e-01,

        8.55667806e-01, 8.68200837e-01],

       [2.86697538e-01, 8.85245902e-01, 1.27529123e-01, 6.25149288e-01,

        8.78621304e-01, 8.01255230e-01],

       [3.15597842e-01, 8.85245902e-01, 2.09687308e-01, 5.51940700e-01,

        8.90878027e-01, 7.94769874e-01],

       [3.43688409e-01, 8.85245902e-01, 1.27529123e-01, 4.75801089e-01,

        8.90878027e-01, 7.10669456e-01]])

阅读 213

收藏
2021-01-20

共1个答案

一尘不染

TLDR: 两者都尝试!


在数据集不平衡之前,我曾遇到过类似情况。我使用train_test_splitKFold通过。

但是,一旦我偶然发现了处理不平衡数据集的问题,便遇到了过度平衡和欠平衡的技术。为此,我建议使用库:imblearn

您将在其中找到各种技巧来处理其中一个类别的人数超过另一个类别的情况。我个人经常使用SMOTE,并且在这种情况下取​​得了相对较好的成功。


其他参考:

https://www.analyticsvidhya.com/blog/2017/03/imbalanced-classification-
problem/

https://towardsdatascience.com/handling-imbalanced-datasets-in-machine-
learning-7a0e84220f28

2021-01-20