数据集中含有10072个图片文件和10072个图片所对应的包含图片中中文字内容的文本文。
task:
1.得到图片数据集中所有的中文字符,构成字符字典,字典大小为所包含不同中文字符的类别数;(dict_size=992,加上一个“空白”,在CTC中一共含有992+1=993个类别)
2.构建训练数据 train_x,train_y; train_x中每一个元素为一张图片(cv2.imread()读取的灰度图),train_y 中每一个元素为图片对应的文字在字符字典中的序号;
code:
def get_char_dict(path):
char_dict = []
txt_files = glob.glob(path + '*.txt')
# print(len(txt_files))
for file in txt_files:
with open(file, 'r') as f:
text = f.readline()
char_dict += text
char_dict = set(char_dict)
char_dict = list(char_dict)
dict_size = len(char_dict)
print("dict_size:{}".format(dict_size))
return char_dict
def get_data(path, char_dict):
train_x = []
train_y = []
txt_files = glob.glob(path + '*.txt')
# random.shuffle(txt_files)
for file in txt_files:
base_name = os.path.basename(file)
file_name, _ = os.path.splitext(base_name)
image = cv2.imread(path + file_name + '.jpg', cv2.IMREAD_GRAYSCALE)
train_x.append(image)
with open(file, 'r') as f:
label = []
text = f.readline()
for c in text:
index = char_dict.index(c)
label.append(index)
train_y.append(label)
# # 若图片路径中含中文字符时,
# # cv2.imread()读取图像失败返回None,
# # 删除为None的数据
# for i, img in enumerate(train_x):
# if train_x[i] is None:
# del train_x[i]
# del train_y[i]
print("train_size:{}".format(len(train_x)))
return train_x, train_y
reference:
glob.glob()
返回path路径下的符合条件的所有文件,然后用for循环对每一个文件进行操作。
import glob
txt_files = glob.glob(path + '*.txt')
for file in txt_files:
***
python open()
用于打开一个文件。创建一个 file 对象,相关的方法才可以调用它进行读写。
with open(file, 'r') as f:
text = f.readline()
readline()函数读取整行,包括 "\n" 字符。
python set()
返回一个无序不重复元素集(这里用于删除重复的中文字符)。
python list()
用于将元组转换为列表。