A Keras CNN for the CIFAR-10 Dataset

xiaoxiao  2021-02-27

1. Network Structure

The model is built with the Keras Sequential API and has six layers: the first three are convolutional layers, the fourth and fifth are fully connected layers, and the sixth is a softmax output layer. The convolutional layers use 32@5x5, 32@5x5, and 64@5x5 kernels; the fully connected layers have 1024 and 256 nodes; the output layer has 10 nodes. See Figure 1 in the appendix.

2. Activation Functions

Keras ships with built-in activation functions such as sigmoid, tanh, relu, and softmax, which can be used directly via activation='relu'. The advanced activations (LeakyReLU, PReLU, ELU, ThresholdedReLU) live in keras.layers.advanced_activations and must be imported from there. This model uses LeakyReLU with alpha=0.2.
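As a minimal illustration of what the layer computes (the math, not Keras's implementation), LeakyReLU with this model's alpha=0.2 can be sketched in numpy:

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: identity for x >= 0, alpha * x for x < 0."""
    return np.where(x >= 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.0, 3.0])
print(leaky_relu(x))  # [-0.4 -0.1  0.   1.   3. ]
```

Unlike plain ReLU, the negative side keeps a small gradient (alpha), which avoids "dead" units.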

3. Weight Initialization

Following Xavier Glorot (2010) and Kaiming He (2015), the model adopts He's (2015) method: initial weights are drawn from a zero-mean normal distribution with standard deviation sqrt(2 / n), where n is the number of input neurons to the layer.
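The rule is easy to sketch directly in numpy (an illustration of the formula, not Keras's he_normal code); n_in is the layer's fan-in, and the shape below is the 1024-to-256 dense layer from this model:

```python
import numpy as np

def he_normal_init(n_in, n_out, seed=0):
    """He (2015) init: N(0, sqrt(2 / n_in)), n_in = fan-in of the layer."""
    rng = np.random.default_rng(seed)
    std = np.sqrt(2.0 / n_in)
    return rng.normal(loc=0.0, scale=std, size=(n_in, n_out))

W = he_normal_init(1024, 256)
print(W.mean(), W.std())  # mean ≈ 0, std ≈ sqrt(2/1024) ≈ 0.0442
```

The sqrt(2/n) scale keeps the variance of pre-activations roughly constant across layers when the activation is ReLU-like, which is why it pairs well with the LeakyReLU used here.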

4. BatchNormalization

Following "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift", the BN operation is placed between the convolution and the activation function.
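The BN transform itself, per-feature normalization over the batch followed by a learned scale gamma and shift beta, can be sketched in numpy (an illustration of the math, not the Keras layer, which also tracks running statistics for inference):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-6):
    """Normalize each feature over the batch axis, then scale and shift.
    In this model the operation sits between Conv2D and LeakyReLU."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

# A toy batch of 250 samples with 4 features, deliberately off-center.
x = np.random.default_rng(0).normal(5.0, 3.0, size=(250, 4))
y = batch_norm(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0), y.std(axis=0))  # ≈ 0 and ≈ 1 per feature
```

With gamma=1 and beta=0 the output is simply standardized; during training the network learns gamma and beta, so it can undo the normalization if that helps.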

5. Optimizer

Keras's built-in optimizers include SGD, RMSprop, Adagrad, Adadelta, Adam, Adamax, and Nadam. This model uses Adam with its default parameters.
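For reference, the Adam update rule with Keras's default hyperparameters (lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8) can be sketched in numpy. The function adam_minimize and the toy quadratic below are illustrative, not part of the model code:

```python
import numpy as np

def adam_minimize(grad_fn, x0, lr=0.001, beta1=0.9, beta2=0.999,
                  eps=1e-8, steps=10000):
    """Adam: momentum (m) plus per-parameter step scaling (v),
    with bias correction for the zero-initialized moments."""
    x = np.asarray(x0, dtype=float)
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g          # first moment (mean of grads)
        v = beta2 * v + (1 - beta2) * g * g      # second moment (mean of grad^2)
        m_hat = m / (1 - beta1 ** t)             # bias correction
        v_hat = v / (1 - beta2 ** t)
        x = x - lr * m_hat / (np.sqrt(v_hat) + eps)
    return x

# Minimize f(x) = (x - 3)^2; the gradient is 2 * (x - 3).
x_min = adam_minimize(lambda x: 2 * (x - 3.0), x0=[0.0])
print(x_min)  # ≈ [3.]
```

Because the step size is normalized by sqrt(v_hat), Adam moves roughly lr per iteration regardless of the raw gradient scale, which is why the default lr=0.001 works across many problems.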

6. Data Preprocessing

The training set is preprocessed in two steps: mean subtraction and variance normalization. When predicting on the test set, subtract the training-set mean first, then normalize by the training-set standard deviation. Mean subtraction: trainData -= numpy.mean(trainData, axis=0); variance normalization: trainData /= numpy.std(trainData, axis=0).
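The key point is that statistics are computed on the training set only and then applied to both sets. A sketch (random arrays stand in for the actual CIFAR-10 data):

```python
import numpy as np

rng = np.random.default_rng(0)
trainData = rng.normal(120.0, 60.0, size=(1000, 32, 32, 3)).astype(np.float32)
testData = rng.normal(120.0, 60.0, size=(200, 32, 32, 3)).astype(np.float32)

# Per-pixel statistics from the TRAINING set only.
mean = trainData.mean(axis=0)
std = trainData.std(axis=0)

# The same transform is applied to both sets, so test inputs
# land in the distribution the network was trained on.
trainData = (trainData - mean) / std
testData = (testData - mean) / std

print(trainData.mean(), trainData.std())  # ≈ 0 and ≈ 1
```

Recomputing mean/std on the test set instead would leak test statistics and shift the inputs relative to what the model saw during training.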

Whether the image channel axis is channels_first or channels_last depends on the backend. The Theano backend uses channels_first, so trainData has shape (numData, channels, Img_W, Img_H) and the BN layer needs axis=1; the TensorFlow backend uses channels_last, so trainData has shape (numData, Img_W, Img_H, channels) and the BN layer needs axis=-1.
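A quick numpy sketch of the two layouts; transposing the axes converts one into the other:

```python
import numpy as np

# channels_last (TensorFlow backend): (numData, Img_W, Img_H, channels)
batch_last = np.zeros((100, 32, 32, 3))

# channels_first (Theano backend): (numData, channels, Img_W, Img_H)
batch_first = np.transpose(batch_last, (0, 3, 1, 2))

print(batch_first.shape)  # (100, 3, 32, 32)
```

The BN axis argument must point at the channel axis in whichever layout is in use, hence axis=1 for channels_first and axis=-1 for channels_last.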

7. Training Results

Training accuracy reaches 0.97 after 20 epochs.

8. Code and Network Structure Diagram

import os
import cv2
import numpy
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Conv2D, MaxPooling2D, Flatten, AveragePooling2D, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras import initializers
from keras.optimizers import Adam
from keras.utils import np_utils
from keras.utils.vis_utils import plot_model

def loadData(path):
    data = []
    labels = []
    for i in range(10):
        dir1 = './' + path + '/' + str(i)
        listImg = os.listdir(dir1)
        for img in listImg:
            imgIn = cv2.imread(dir1 + '/' + img)
            if imgIn.size != 3072:  # 32 x 32 x 3
                print('Img error')
                continue
            data.append(imgIn)
            labels.append(i)
        print(path, i, 'is read')
    return data, labels

trainData, trainLabels = loadData('train')
#testData, testLabels = loadData('test1')
trainLabels = np_utils.to_categorical(trainLabels, 10)
#testLabels = np_utils.to_categorical(testLabels, 10)

trainData = numpy.reshape(trainData, (len(trainData), 32, 32, 3))
trainData = trainData.astype(numpy.float32)
trainData -= numpy.mean(trainData, axis=0)
trainData /= numpy.std(trainData, axis=0)
#print(trainData[-1])

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='same', input_shape=(32, 32, 3),
                 data_format='channels_last', kernel_initializer=initializers.he_normal()))
model.add(BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-6))
model.add(LeakyReLU(alpha=0.2))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='same',
                 data_format='channels_last', kernel_initializer=initializers.he_normal()))
model.add(BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-6))
model.add(LeakyReLU(alpha=0.2))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Conv2D(filters=64, kernel_size=(5, 5), padding='same',
                 data_format='channels_last', kernel_initializer=initializers.he_normal()))
model.add(BatchNormalization(axis=-1, momentum=0.9, epsilon=1e-6))
model.add(LeakyReLU(alpha=0.2))
model.add(AveragePooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(1024, kernel_initializer=initializers.he_normal()))
model.add(BatchNormalization(epsilon=1e-6))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256, kernel_initializer=initializers.he_normal()))
model.add(LeakyReLU(alpha=0.1))
model.add(Dense(10, activation='softmax', kernel_initializer=initializers.he_normal()))

adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-8)
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
result = model.fit(trainData, trainLabels, batch_size=250, epochs=20, verbose=1, shuffle=True)

plot_model(model, to_file='model.png', show_shapes=True, show_layer_names=False)
plt.figure()
plt.plot(result.epoch, result.history['acc'], label='acc')
plt.scatter(result.epoch, result.history['acc'], marker='*')
plt.legend(loc='right')
plt.show()

Please credit the original source when reposting: https://www.6miu.com/read-659.html
