diff --git a/paddle2.0_docs/data_loader/data_loader_v3.ipynb b/paddle2.0_docs/data_loader/data_loader_v3.ipynb new file mode 100644 index 00000000..b952832d --- /dev/null +++ b/paddle2.0_docs/data_loader/data_loader_v3.ipynb @@ -0,0 +1,466 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 环境\n", + "本教程基于paddle2.0-alpha编写,如果您的环境不是本版本,请先安装paddle2.0-alpha。" + ] + }, + { + "cell_type": "code", + "execution_count": 295, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2.0.0-alpha0'" + ] + }, + "execution_count": 295, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import paddle\n", + "paddle.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 296, + "metadata": {}, + "outputs": [], + "source": [ + "#数据准备\n", + "#数据处理部分之前的代码,加入部分数据处理的库\n", + "import paddle\n", + "from paddle.imperative import to_variable\n", + "import numpy as np\n", + "import os\n", + "import gzip #解压缩包,python自带的包\n", + "import json\n", + "import random\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1.数据读取与数据集划分\n", + "加载json数据文件。" + ] + }, + { + "cell_type": "code", + "execution_count": 297, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading mnist dataset from /Users/liushuangqiao/Downloads/mnist.json.gz ......\n", + "mnist dataset load done\n", + "训练数据集数量: 50000 50000\n", + "验证数据集数量: 10000 10000\n", + "测试数据集数量: 10000 10000\n" + ] + } + ], + "source": [ + "# 声明数据集文件位置\n", + "datafile = '/Users/liushuangqiao/Downloads/mnist.json.gz'\n", + "print('loading mnist dataset from {} ......'.format(datafile))\n", + "# 加载json数据文件\n", + "data = json.load(gzip.open(datafile))\n", + "print('mnist dataset load done')\n", + "# 读取到的数据区分训练集,验证集,测试集\n", + "train_set, val_set, eval_set = data\n", + "\n", + "# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS\n", + "IMG_ROWS = 28\n", + "IMG_COLS = 28\n", + "\n", + "# 打印数据信息\n", + "imgs, labels = train_set[0], train_set[1]\n", + "print(\"训练数据集数量: \", len(imgs),len(labels))\n", + "\n", + "# 观察验证集数量\n", + "imgs, labels = val_set[0], val_set[1]\n", + "print(\"验证数据集数量: \", len(imgs),len(labels))\n", + "\n", + "# 观察测试集数量\n", + "imgs, labels = val= eval_set[0], eval_set[1]\n", + "print(\"测试数据集数量: \", len(imgs),len(labels))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. 通过DataSet与DataLoader获取数据" + ] + }, + { + "cell_type": "code", + "execution_count": 298, + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.io import Dataset\n", + "\n", + "#定义Dataset类对象\n", + "class RandomDataset(Dataset):\n", + " def __init__(self, imgs, labels):\n", + " self.imgs = imgs\n", + " self.labels = labels\n", + " \n", + " def __getitem__(self, idx):\n", + " img = self.imgs[idx]\n", + " label = self.labels[idx]\n", + " return img, label\n", + " \n", + " def __len__(self):\n", + " return len(self.imgs)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 299, + "metadata": {}, + "outputs": [], + "source": [ + "#通过DataLoader读取dataset数据,涉及必要参数 :dataset、places=None、batch_size\n", + "def load_data_new(mode='train'):\n", + " datafile = '/Users/liushuangqiao/Downloads/mnist.json.gz'\n", + " print('loading mnist dataset from {} ......'.format(datafile))\n", + " # 定义批大小\n", + " BATCH_SIZE = 64\n", + " # 加载json数据文件\n", + " data = json.load(gzip.open(datafile))\n", + " print('mnist dataset load done')\n", + " # 读取到的数据区分训练集,验证集,测试集\n", + " train_set, val_set, eval_set = data\n", + " if mode=='train':\n", + " # 获得训练数据集\n", + " imgs, labels = train_set[0], train_set[1]\n", + " elif mode=='valid':\n", + " # 获得验证数据集\n", + " imgs, labels = val_set[0], val_set[1]\n", + " elif mode=='eval':\n", + " # 获得测试数据集\n", + " imgs, labels = eval_set[0], eval_set[1]\n", + " else:\n", + " raise Exception(\"mode can only be one of ['train', 'valid', 'eval']\")\n", + " dataset = RandomDataset(imgs, labels)\n", + " loader = paddle.io.DataLoader(dataset, places=paddle.CPUPlace(),batch_size=BATCH_SIZE, drop_last=True)\n", + " return loader" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. 数据校验" + ] + }, + { + "cell_type": "code", + "execution_count": 300, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading mnist dataset from /Users/liushuangqiao/Downloads/mnist.json.gz ......\n", + "mnist dataset load done\n", + "[64, 784] [64] \n", + "\n", + "打印第一个batch的第一个图像,对应标签数字为[5]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAQbUlEQVR4nO3dfbBU9X3H8fdHHhV8AB8oIlWC+BgjJndQIzV2Eqk6ddTplErbFBktmpHWTGir0mlk0jqjHaMhjtpiJWKqJFhjJDPmwTBOTKaReDEoqFURMIrXi4gKPgQvl2//2AOz6t3fvezu3V3u7/OauXPPnu85e7539cPZ3d/Z/SkiMLOBb59mN2BmjeGwm2XCYTfLhMNulgmH3SwTDrtZJhz2AUrSM5LOanYf1jrkcXazPPjMbpYJh32AkrRB0pckzZd0v6T/lrRN0mpJx0i6VtImSa9Imla23yxJzxXbrpN0+cfu958kdUh6TdJlkkLS0UVtmKSbJP1OUqek/5C0b6P/duuZw56H84HvAqOA3wI/pfTffhzwDeA/y7bdBPwpcAAwC7hF0mcBJJ0DfA34EnA0cNbHjnMDcAwwuaiPA77eH3+Q7Tm/Zh+gJG0ALgOmAmdExNnF+vOBJcCBEdEtaX9gKzAqIt7u4X5+CDwaEQskLQI6I+LaonY08CIwCXgJeBf4TES8VNRPB+6LiAn9+9daXwxudgPWEJ1lyx8AmyOiu+w2wEjgbUnnAtdROkPvA+wHrC62ORxoL7uvV8qWDy22XSlp1zoBg+r0N1iNHHbbTdIw4AHgb4CHIqKrOLPvSm8HcETZLuPLljdT+ofjxIjY2Ih+bc/4NbuVGwoMA94AdhRn+Wll9aXALEnHS9oP+JddhYjYCdxJ6TX+YQCSxkn6k4Z1b0kOu+0WEduAv6cU6reAvwSWldV/DHwbeBRYCzxelLYXv6/etV7SVuDnwLENad565TforGqSjgfWAMMiYkez+7E0n9ltj0i6qBhPHwXcCPzIQd87OOy2py6nNBb/EtANfKW57Vhf+Wm8WSZ8ZjfLREPH2YdqWAxnRCMPaZaV3/MeH8Z29VSrKezFtdILKF0l9V8RcUNq++GM4FR9sZZDmlnCilhesVb103hJg4DbgHOBE4AZkk6o9v7MrH/V8pp9CrA2ItZFxIfA94AL6tOWmdVbLWEfx0c/CPFqse4jJM2W1C6pvWv3hVZm1mj9/m58RCyMiLaIaBvCsP4+nJlVUEvYN/LRTz0dUawzsxZUS9ifACZJmiBpKHAxZR+aMLPWUvXQW0TskDSH0lccDQIWRcQzdevMzOqqpnH2iHgYeLhOvZhZP/LlsmaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulomaZnG1vcA+g5LlwYcd0q+Hf/4fJlSsdY/Ymdz3yImbkvURV6SP3fGtYRVrT7Z9P7nv5u73kvXTls5N1ifOfTxZb4aawi5pA7AN6AZ2RERbPZoys/qrx5n9jyNicx3ux8z6kV+zm2Wi1rAH8DNJKyXN7mkDSbMltUtq72J7jYczs2rV+jR+akRslHQY8Iik/4uIx8o3iIiFwEKAAzQ6ajyemVWppjN7RGwsfm8CHgSm1KMpM6u/qsMuaYSk/XctA9OANfVqzMzqq5an8WOAByXtup/7IuIndelqgBl07NHJegwfkqx3fGFUsv7+6ZXHhEcfmB4v/uXJ6fHmZvrx+/sn6zfefk6yvuKk+yrW1nd9kNz3hs6zk/XDf5m+RqAVVR32iFgHnFzHXsysH3nozSwTDrtZJhx2s0w47GaZcNjNMuGPuNbBzi+ckqzfvPj2ZP2YIUPr2c5eoyu6k/Wvf/uSZH3Ie+kLMj+/dE7F2siNXcl9h21OD83tu/I3yXor8pndLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEx9nrYOjzryXrK38/Plk/ZkhnPdupq7kdpyXr695NfxX13RP/p2LtnZ3pcfIxt/5vst6fBuJXKvnMbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlQhGNG1E8QKPjVH2xYcdrFW9dcnqy/s656a97HvzUyGR91ZW37nFPu/zb5s8k60/80cHJevfWrekDnFb5/tdfpeSuE2Y8lb5v+4QVsZytsaXHB9ZndrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEx5nbwGDDh6drHe/uSVZX7+k8mS6z5y5KLnvlOv/Llk/7Pbmfabc9lxN4+ySFknaJGlN2brRkh6R9GLxOz2BuJk1XV+ext8NfHzW+2uA5RExCVhe3DazFtZr2CPiMeDjzyMvABYXy4uBC+vcl5nVWbXfQTcmIjqK5deBMZU2lDQbmA0wnP2qPJyZ1armd+Oj9A5fxXf5ImJhRLRFRNsQhtV6ODOrUrVh75Q0FqD4val+LZlZf6g27MuAmcXyTOCh+rRjZv2l19fskpYAZwGHSHoVuA64AVgq6VLgZWB6fzY50PU2jt6brq3Vz+9+4l8/m6y/cUf6M+c08DoNq02vYY+IGRVKvjrGbC/iy2XNMuGwm2XCYTfLhMNulgmH3SwTnrJ5ADj+H5+vWJt1UnrQ5DtHLk/Wz/rzK5P1kUsfT9atdfjMbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwuPsA0Bq2uQ3rzguue/vfvRBsn719fck69f+xUXJejx5YMXa+Ov9NdWN5DO7WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJT9mcuS2zTk/W751/U7I+YfDwqo994uI5yfqkO19L1nesf7nqYw9UNU3ZbGYDg8NulgmH3SwTDrtZJhx2s0w47GaZcNjNMuFxdkuKz5+crB9w48Zkfcmnflr1sY979LJk/dj5byfr3WvXV33svVVN4+ySFknaJGlN2br5kjZKWlX8nFfPhs2s/vryNP5u4Jwe1t8SEZOLn4fr25aZ1VuvYY+Ix4AtDejFzPpRLW/QzZH0dPE0f1SljSTNltQuqb2L7TUczsxqUW3Y7wAmApOBDuCblTaMiIUR0RYRbUMYVuXhzKxWVYU9IjojojsidgJ3AlPq25aZ1VtVYZc0tuzmRcCaStuaWWvodZxd0hLgLOAQoBO4rrg9GQhgA3B5RHT0djCPsw88gw49NFl/7eJJFWsrrlmQ3HefXs5Ff7V+WrL+ztQ3k/WBKDXO3uskERExo4fVd9XclZk1lC+XNcuEw26WCYfdLBMOu1kmHHazTPgjrtY0S1/9dbK+n4Ym6+/Hh8n6+XOuqljb94e/Se67t/JXSZuZw26WC4fdLBMOu1kmHHazTDjsZplw2M0y0eun3ixvccbkZH3t9PSUzZ+evKFirbdx9N7cuuWUZH3fh56o6f4HGp/ZzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMeJx9gNPnTkzWX7gqPUvPnWcsTtbPHJ7+THkttkdXsv74lgnpO+j9282z4jO7WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpaJXsfZJY0H7gHGUJqieWFELJA0Gvg+cBSlaZunR8Rb/ddqvgYf9YfJ+kuzjqhYm3/xkuS+fzZyc1U91cO8zrZk/bFvnZasH3RP+nvn7aP6cmbfAcyNiBOA04ArJZ0AXAMsj4hJwPLitpm1qF7DHhEdEfFksbwNeA4YB1wA7Lq8ajFwYX81aWa126PX7JKOAk4BVgBjInZfj/g6paf5Ztai+hx2SSOBB4CvRsTW8lqUJozrcdI4SbMltUtq72J7Tc2aWfX6FHZJQygF/d6I+EGxulPS2KI+FtjU074RsTAi2iKibQjpD12YWf/pNeySBNwFPBcRN5eVlgEzi+WZwEP1b8/M6qUvH3E9A/gysFrSqmLdPOAGYKmkS4GXgen90+Leb/CR45P1rZ87PFmf/q8/SdavOOjBPe6pXuZ2pIfHfn1b5eG10Xc/ntz3oPDQWj31GvaI+BXQ43zPgCdbN9tL+Ao6s0w47GaZcNjNMuGwm2XCYTfLhMNulgl/lXQfDf6Dypf+b/nOyOS+X5nwi2R9xv6dVfVUD3M2Tk3Wf3t7esrmg+9/Olkf/Z7HyluFz+xmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSayGWfvmpb+2uLtX9uSrM87+uGKtWn7vldVT/XS2f1BxdqZy+Ym9z1u3nPJ+qit6XHyncmqtRKf2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTGQzzr7+ovS/ay+cdH+/Hfu2tycm6wt+MS1ZV3elb/IuOe4b6yrWJr2xIrlvd7JqA4nP7GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhQR6Q2k8cA9wBgggIURsUDSfOBvgTeKTedFROUPfQMHaHScKs/ybNZfVsRytsaWHi/M6MtFNTuAuRHxpKT9gZWSHilqt0TETfVq1Mz6T69hj4gOoKNY3ibpOWBcfzdmZvW1R6/ZJR0FnALsugZzjqSnJS2SNKrCPrMltUtq72J7Tc2aWfX6HHZJI4EHgK9GxFbgDmAiMJnSmf+bPe0XEQsjoi0i2oYwrA4tm1k1+hR2SUMoBf3eiPgBQER0RkR3ROwE7gSm9F+bZlarXsMuScBdwHMRcXPZ+rFlm10ErKl/e2ZWL315N/4M4MvAakmrinXzgBmSJlMajtsAXN4vHZpZXfTl3fhfAT2N2yXH1M2stfgKOrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpaJXr9Kuq4Hk94AXi5bdQiwuWEN7JlW7a1V+wL3Vq169nZkRBzaU6GhYf/EwaX2iGhrWgMJrdpbq/YF7q1ajerNT+PNMuGwm2Wi2WFf2OTjp7Rqb63aF7i3ajWkt6a+Zjezxmn2md3MGsRhN8tEU8Iu6RxJz0taK+maZvRQiaQNklZLWiWpvcm9LJK0SdKasnWjJT0i6cXid49z7DWpt/mSNhaP3SpJ5zWpt/GSHpX0rKRnJF1VrG/qY5foqyGPW8Nfs0saBLwAnA28CjwBzIiIZxvaSAWSNgBtEdH0CzAknQm8C9wTEZ8u1v07sCUibij+oRwVEVe3SG/zgXebPY13MVvR2PJpxoELgUto4mOX6Gs6DXjcmnFmnwKsjYh1EfEh8D3ggib00fIi4jFgy8dWXwAsLpYXU/qfpeEq9NYSIqIjIp4slrcBu6YZb+pjl+irIZoR9nHAK2W3X6W15nsP4GeSVkqa3exmejAmIjqK5deBMc1spge9TuPdSB+bZrxlHrtqpj+vld+g+6SpEfFZ4FzgyuLpakuK0muwVho77dM03o3SwzTjuzXzsat2+vNaNSPsG4HxZbePKNa1hIjYWPzeBDxI601F3blrBt3i96Ym97NbK03j3dM047TAY9fM6c+bEfYngEmSJkgaClwMLGtCH58gaUTxxgmSRgDTaL2pqJcBM4vlmcBDTezlI1plGu9K04zT5Meu6dOfR0TDf4DzKL0j/xLwz83ooUJfnwKeKn6eaXZvwBJKT+u6KL23cSlwMLAceBH4OTC6hXr7LrAaeJpSsMY2qbeplJ6iPw2sKn7Oa/Zjl+irIY+bL5c1y4TfoDPLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMvH/TswJIRNpLrYAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "#声明数据读取函数,从训练集中读取数据\n", + "paddle.enable_imperative()\n", + "train_loader = load_data_new('train')\n", + "for batch_id, data in enumerate(train_loader()):\n", + "\n", + " image_data, label_data = data[0], data[1] \n", + " if batch_id == 0:\n", + " # 打印数据shape和类型\n", + " print(image_data.shape, label_data.shape, type(image_data), type(label_data))\n", + " print(\"\\n打印第一个batch的第一个图像,对应标签数字为{}\".format(label_data[0].numpy()))\n", + " # 原始数据是归一化后的数据,因此这里需要反归一化\n", + " img = np.array(image_data[0]+1)*127.5\n", + " img = np.reshape(img, [28, 28]).astype(np.uint8)\n", + " plt.figure(\"Image\") # 图像窗口名称\n", + " plt.imshow(img)\n", + " plt.axis('on') # 关掉坐标轴为 off\n", + " plt.title('image') # 图像题目\n", + " plt.show()\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. 组成网络" + ] + }, + { + "cell_type": "code", + "execution_count": 301, + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.nn import Conv2D, Pool2D, Linear\n", + "#定义网络结构,这里使用最简单的线性网络\n", + "class Mnist(paddle.nn.Layer):\n", + " def __init__(self, name_scope):\n", + " super(Mnist, self).__init__()\n", + " self.fc = Linear(input_dim=784, output_dim=10, act='softmax', dtype='float64')\n", + "\n", + " # 定义网络结构的前向计算过程\n", + " def forward(self, inputs,label=None):\n", + " outputs = self.fc(inputs)\n", + " if label is not None:\n", + " acc = paddle.metric.accuracy(input=outputs, label=label)\n", + " return outputs, acc\n", + " else:\n", + " return outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. 训练模型\n", + "在训练模型前,需要设置模型的运行环境,这里我们设置模型在cpu上运行,并将其设置为动态图模式。" + ] + }, + { + "cell_type": "code", + "execution_count": 302, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch: 0, batch: 0, loss is: [3.13129952]\n", + "epoch: 0, batch: 100, loss is: [2.82868815]\n", + "epoch: 0, batch: 200, loss is: [2.5842488]\n", + "epoch: 0, batch: 300, loss is: [3.21580688]\n", + "epoch: 0, batch: 400, loss is: [3.03717391]\n", + "epoch: 0, batch: 500, loss is: [2.84022745]\n", + "epoch: 0, batch: 600, loss is: [2.85783756]\n", + "epoch: 0, batch: 700, loss is: [2.76853633]\n", + "epoch: 1, batch: 0, loss is: [3.13129952]\n", + "epoch: 1, batch: 100, loss is: [2.82868815]\n", + "epoch: 1, batch: 200, loss is: [2.5842488]\n", + "epoch: 1, batch: 300, loss is: [3.21580688]\n", + "epoch: 1, batch: 400, loss is: [3.03717391]\n", + "epoch: 1, batch: 500, loss is: [2.84022745]\n", + "epoch: 1, batch: 600, loss is: [2.85783756]\n", + "epoch: 1, batch: 700, loss is: [2.76853633]\n", + "epoch: 2, batch: 0, loss is: [3.13129952]\n", + "epoch: 2, batch: 100, loss is: [2.82868815]\n", + "epoch: 2, batch: 200, loss is: [2.5842488]\n", + "epoch: 2, batch: 300, loss is: [3.21580688]\n", + "epoch: 2, batch: 400, loss is: [3.03717391]\n", + "epoch: 2, batch: 500, loss is: [2.84022745]\n", + "epoch: 2, batch: 600, loss is: [2.85783756]\n", + "epoch: 2, batch: 700, loss is: [2.76853633]\n", + "epoch: 3, batch: 0, loss is: [3.13129952]\n", + "epoch: 3, batch: 100, loss is: [2.82868815]\n", + "epoch: 3, batch: 200, loss is: [2.5842488]\n", + "epoch: 3, batch: 300, loss is: [3.21580688]\n", + "epoch: 3, batch: 400, loss is: [3.03717391]\n", + "epoch: 3, batch: 500, loss is: [2.84022745]\n", + "epoch: 3, batch: 600, loss is: [2.85783756]\n", + "epoch: 3, batch: 700, loss is: [2.76853633]\n", + "epoch: 4, batch: 0, loss is: [3.13129952]\n", + "epoch: 4, batch: 100, loss is: [2.82868815]\n", + "epoch: 4, batch: 200, loss is: [2.5842488]\n", + "epoch: 4, batch: 300, loss is: [3.21580688]\n", + "epoch: 4, batch: 400, loss is: [3.03717391]\n", + "epoch: 4, batch: 500, loss is: [2.84022745]\n", + "epoch: 4, batch: 600, loss is: [2.85783756]\n", + "epoch: 4, batch: 700, loss is: [2.76853633]\n", + "epoch: 5, batch: 0, loss is: [3.13129952]\n", + "epoch: 5, batch: 100, loss is: [2.82868815]\n", + "epoch: 5, batch: 200, loss is: [2.5842488]\n", + "epoch: 5, batch: 300, loss is: [3.21580688]\n", + "epoch: 5, batch: 400, loss is: [3.03717391]\n", + "epoch: 5, batch: 500, loss is: [2.84022745]\n", + "epoch: 5, batch: 600, loss is: [2.85783756]\n", + "epoch: 5, batch: 700, loss is: [2.76853633]\n", + "epoch: 6, batch: 0, loss is: [3.13129952]\n", + "epoch: 6, batch: 100, loss is: [2.82868815]\n", + "epoch: 6, batch: 200, loss is: [2.5842488]\n", + "epoch: 6, batch: 300, loss is: [3.21580688]\n", + "epoch: 6, batch: 400, loss is: [3.03717391]\n", + "epoch: 6, batch: 500, loss is: [2.84022745]\n", + "epoch: 6, batch: 600, loss is: [2.85783756]\n", + "epoch: 6, batch: 700, loss is: [2.76853633]\n", + "epoch: 7, batch: 0, loss is: [3.13129952]\n", + "epoch: 7, batch: 100, loss is: [2.82868815]\n", + "epoch: 7, batch: 200, loss is: [2.5842488]\n", + "epoch: 7, batch: 300, loss is: [3.21580688]\n", + "epoch: 7, batch: 400, loss is: [3.03717391]\n", + "epoch: 7, batch: 500, loss is: [2.84022745]\n", + "epoch: 7, batch: 600, loss is: [2.85783756]\n", + "epoch: 7, batch: 700, loss is: [2.76853633]\n", + "epoch: 8, batch: 0, loss is: [3.13129952]\n", + "epoch: 8, batch: 100, loss is: [2.82868815]\n", + "epoch: 8, batch: 200, loss is: [2.5842488]\n", + "epoch: 8, batch: 300, loss is: [3.21580688]\n", + "epoch: 8, batch: 400, loss is: [3.03717391]\n", + "epoch: 8, batch: 500, loss is: [2.84022745]\n", + "epoch: 8, batch: 600, loss is: [2.85783756]\n", + "epoch: 8, batch: 700, loss is: [2.76853633]\n", + "epoch: 9, batch: 0, loss is: [3.13129952]\n", + "epoch: 9, batch: 100, loss is: [2.82868815]\n", + "epoch: 9, batch: 200, loss is: [2.5842488]\n", + "epoch: 9, batch: 300, loss is: [3.21580688]\n", + "epoch: 9, batch: 400, loss is: [3.03717391]\n", + "epoch: 9, batch: 500, loss is: [2.84022745]\n", + "epoch: 9, batch: 600, loss is: [2.85783756]\n", + "epoch: 9, batch: 700, loss is: [2.76853633]\n" + ] + } + ], + "source": [ + "# 定义MNIST类的对象,以及优化器\n", + "mnist = Mnist(\"mnist\")\n", + "\n", + "# 定义优化器\n", + "optimizer = paddle.optimizer.Adam(learning_rate=0.1,parameter_list=mnist.parameters())\n", + "\n", + "EPOCH_NUM = 10\n", + "for epoch_id in range(EPOCH_NUM):\n", + " for batch_id, data in enumerate(train_loader()):\n", + " #准备数据\n", + " image_data, label_data = data[0], data[1]\n", + "\n", + " #前向计算的过程\n", + " predict = mnist(image_data)\n", + "\n", + " #计算损失,取一个批次样本损失的平均值\n", + " loss = paddle.nn.functional.cross_entropy(predict,label_data)\n", + " avg_loss = paddle.mean(loss)\n", + "\n", + " #每训练了100批次的数据,打印下当前Loss的情况\n", + " if batch_id % 100 == 0:\n", + " print(\"epoch: {}, batch: {}, loss is: {}\".format(epoch_id, batch_id, avg_loss.numpy()))\n", + "\n", + " #后向传播,更新参数的过程\n", + " avg_loss.backward()\n", + " optimizer.minimize(avg_loss)\n", + " mnist.clear_gradients()\n", + "\n", + "#保存模型参数\n", + "model_dict = mnist.state_dict()\n", + "paddle.imperative.save(model_dict, \"save_temp\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. 评估测试" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "mnist_eval = Mnist(\"mnist\") \n", + "model_dict, _ = paddle.imperative.load(\"save_temp\")\n", + "mnist_eval.load_dict(model_dict)\n", + "\n", + "#切换到评估模式\n", + "mnist_eval.eval()\n", + "\n", + "acc_set = []\n", + "avg_loss_set = []\n", + "\n", + "# 定义数据加载器\n", + "test_loader = load_data_new('eval')\n", + "for batch_id, data in enumerate(test_loader()):\n", + " image_data, label_data = data[0],data[1]\n", + " label_data = paddle.reshape(label_data,[-1,1])\n", + " \n", + " #前向计算的过程\n", + " predict, acc = mnist_eval(image_data, label_data)\n", + "\n", + " #计算损失,取一个批次样本损失的平均值\n", + " loss = paddle.nn.functional.cross_entropy(predict,label_data)\n", + " avg_loss = paddle.mean(loss)\n", + " acc_set.append(float(acc.numpy()))\n", + " avg_loss_set.append(float(avg_loss.numpy()))\n", + " \n", + "acc_val_mean = np.array(acc_set).mean()\n", + "avg_loss_val_mean = np.array(avg_loss_set).mean()\n", + "print(\"Eval avg_loss is: {}, acc is: {}\".format(avg_loss_val_mean, acc_val_mean))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}