TODO:
第一篇介绍numpy库和matplotlib库、读写二进制的方法、pkl等。这些知识会在后面用到,在本篇的最后会以mnist数据集为例,创建处理手写体图片的函数,供后使用。
本篇思维导图如下
本篇是基础部分,本着简洁的目的,将需要了解的资料贴出来。不浪费太多时间讲解,反正后面会用到。
[TOC]
1. python基础知识
1.1 class 和function
详情请看链接
python3 函数
1.2 numpy
在深度学习的实现中会使用矩阵进行计算,numpy中实现了很多数据组的运算方法,在后期会用到。
1.2.1 Ndarray
Numpy中主要的数据结构是Ndarray,用于存放同类型元素的多维数组。
数据类型:dtype,描述数据类型,可以计算每个元素大小;
数组形状:shape,描述数组的大小和形状;
跨度元组,stride:表示从前一个维度到下一个维度需要跨越的字节数;
data: 指向数组的地址
ps: 后期会用到dtype, shape等成员变量
1.2.2 切片和索引
numpy的切片和索引的有关内容在 fancy-indexing-and-index-tricks 中可以找到。
1.2.3 广播机制
Numpy对于不同形状的乘法采用了广播机制。广播是一种ufunc的机制是 不同形状的数组之间执行算数运算的方式,需要遵循4个原则:
1.让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
2.输入数组的shape是输入数组shape的各个轴上的最大值
3.如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错。
4.输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值。
广播可以对不同形状的数组做点乘:将较小的形状按照一定的规则填充,填充的方向依次为由内向外;广播机制在cudnn、tensorflow等深度学习框架中同样会使用。
如下图所示:
第一个矩阵是22, 但第二个矩阵并不是22的,按照数学运算法则是不能做点乘的;
但如果有广播机制,会按照以下方式填充数据,并做乘法:
更多关于广播机制,详见: basics.broadcasting
1.2.3 numpy的其他知识点
numpy的其他知识点详见
numpy-quickstart.html
1.3 matplotlib和skimage
1.4 序列化
1.5 mnist数据集处理
# -*- coding: utf-8 -*-
# @File : day7.py
# @Author: lizhen
# @Date : 2020/2/4
# @Desc :
import urllib.request # python3
import os.path
import gzip
import pickle
import os
import numpy as np
# http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
# http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
# http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
# http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
url_base = "http://yann.lecun.com/exdb/mnist/"
key_file = {
'train_img':'train-images-idx3-ubyte.gz',
'train_label':'train-labels-idx1-ubyte.gz',
'test_img':'t10k-images-idx3-ubyte.gz',
'test_label':'t10k-labels-idx1-ubyte.gz'
}
dataset_dir=os.path.dirname(os.path.abspath(__file__))
save_file=dataset_dir + "/mnist.pkl"
train_num = 60000;
test_num = 10000;
img_dim = (1, 28, 28)
img_size = 28*28;
def _download(file_name):
"""
:param file_name: 下载mnist的文件
:return: null
"""
file_path = os.path.join(dataset_dir, file_name)
if os.path.exists(file_path):
return
print("downloading"+file_name+ "...")
urllib.request.urlretrieve(url_base + file_name , file_path)
print("Done.")
def download_mnist():
"""
:return:
"""
for file_name in key_file.values():
_download(file_name);
def _load_label(file_name):
"""
解析标签
:param file_name:
:return:
"""
file_path = dataset_dir+'/'+ file_name
print("converting "+file_name+" to numpy Array.")
with gzip.open(file_path) as f:
labels = np.frombuffer(f.read(), np.uint8, offset=8)
print("Done")
return labels
def _load_img(file_name):
"""
解析 压缩的图片
:param file_name:
:return:
"""
file_path = dataset_dir +'/' + file_name
print("converting "+ file_name + "to numpy Array")
with gzip.open(file_path) as f:
data = np.frombuffer(f.read(), np.uint8, offset=16) # 16*8=
data = data.reshape(-1, img_size) # N, (W*H*C)=[N,28*28*1]
print("Done")
return data
def _convert_numpy():
"""
解析 image和label,将其转换为numpy
"""
dataset = {}
dataset['train_img'] = _load_img(key_file['train_img'])
dataset['train_label'] = _load_label(key_file['train_label'])
dataset['test_img'] = _load_img(key_file['test_img'])
dataset['test_label'] = _load_label(key_file['test_label'])
return dataset
def init_mnist():
"""
初始化mnist数据集:
1. 下载mnist,
2. 以二进制的方式读取,并转换成numpy的ndarray对象
3. 将转换后的ndarray 序列化
:return:
"""
print("download mnist dataset...")
download_mnist()
print("convert to numpy array...")
dataset = _convert_numpy()
print("creating pickle file ...")
with open(save_file, 'wb') as f:
pickle.dump(dataset, f, -1)
print("Done!")
def _change_one_hot_label(Y):
T = np.zeros((Y.size,10))
for idx,row in enumerate(T):
row[Y[idx]] = 1
return T
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
"""
:param normalize: 将数据标准化到0.0~1.0
:param flatten: 是否要将数据拉伸层1D数组的形式
:param one_hot_label:
:return: (训练数据, 训练标签), (测试数据, 测试label)
"""
if not os.path.exists(save_file):
init_mnist()
with open(save_file,'rb') as f:
dataset = pickle.load(f)
if normalize:
for key in ('train_img','test_img'):
dataset[key] = dataset[key].astype(np.float32)
dataset[key] /=255.0
if one_hot_label:
dataset['train_label'] = _change_one_hot_label(dataset['train_label'])
dataset['test_label'] = _change_one_hot_label(dataset['test_label'])
if not flatten:
for key in ('train_img', 'test_img'):
dataset[key] = dataset[key].reshape(-1,1,28,28) # NCHW
return (dataset['train_img'],dataset['train_label']),(dataset['test_img'], dataset['test_label'])
if __name__ == '__main__':
init_mnist()