python中sklearn实现交叉验证

xiaoxiao2021-02-27  414

质量要比数量重要,就像一个本垒打胜过两个双打。——《蚂蚁金服》

1、概述

在实验数据分析中,有些算法需要用现有的数据构建模型,如卷积神经网络(CNN),这类算法称为监督学习(Supervisied Learning)。构建模型需要的数据称为训练数据。

模型构建完后,需要利用数据验证模型的正确性,这部分数据称为测试数据。测试数据不能用于构建模型中,只能用于最后检验模型的准确性。

有时候模型的构建的过程中,也需要检验模型,辅助模型构建。所以会将训练数据分为两个部分,1)训练数据;2)验证数据。 将数据分类就要采用交叉验证的方法,个人写的交叉验证算法不可避免有一定缺陷,考虑使用强大sklearn包可以实现交叉验证算法。

2、python实现

请注意:以下的方法实现根据最新的sklearn版本实现,老版本的函数很多已经过期。

2.1 K次交叉检验(K-Fold Cross Validation)

K次交叉检验的大致思想是将数据大致分为K个子样本,每次取一个样本作为验证数据,取余下的K-1个样本作为训练数据。

from sklearn.model_selection import KFold import numpy as np X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) y = np.array([1, 2, 3, 4]) kf = KFold(n_splits=2) for train_index, test_index in kf.split(X): print("TRAIN:", train_index, "TEST:", test_index) X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]

2.2 Stratified k-fold

StratifiedKFold()这个函数较常用,比KFold的优势在于将k折数据按照百分比划分数据集,每个类别百分比在训练集和测试集中都是一样,这样能保证不会有某个类别的数据在训练集中而测试集中没有这种情况,同样不会在训练集中没有全在测试集中,这样会导致结果糟糕透顶。

from sklearn.model_selection import StratifiedKFold import numpy as np X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]]) y = np.array([0, 0, 1, 1]) skf = StratifiedKFold(n_splits=2) for train_index, test_index in skf.split(X, y): print("TRAIN:", train_index, "TEST:", test_index) X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]

2.3 train_test_split

随机根据比例分配训练集和测试集。这个函数可以调整随机种子。

import numpy as np from sklearn.model_selection import train_test_split X, y = np.arange(10).reshape((5, 2)), range(5) X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.33, random_state=42)

3、总结

sklearn功能非常强大提供的交叉验证函数也非常多,如:

'BaseCrossValidator', 'GridSearchCV', 'TimeSeriesSplit', 'KFold', 'GroupKFold', 'GroupShuffleSplit', 'LeaveOneGroupOut', 'LeaveOneOut', 'LeavePGroupsOut', 'LeavePOut', 'ParameterGrid', 'ParameterSampler', 'PredefinedSplit', 'RandomizedSearchCV', 'ShuffleSplit', 'StratifiedKFold', 'StratifiedShuffleSplit', 'check_cv', 'cross_val_predict', 'cross_val_score', 'fit_grid_point', 'learning_curve', 'permutation_test_score', 'train_test_split', 'validation_curve'

感兴趣的可以查看sklearn源码。

转载请注明原文地址: https://www.6miu.com/read-2609.html

最新回复(0)