- Text files that need preprocessing can be processed once and then stored as a binary file with `torch.save` (or `pickle`), so later runs only need to load the cached file instead of redoing the work.
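
A minimal sketch of that caching pattern (the `load_or_build` helper and `build_fn` argument are placeholders for illustration, not part of the tutorial code):

```python
import os
import torch

def load_or_build(cache_path, build_fn):
    """Load a preprocessed object from cache_path, or build it once and cache it."""
    if cache_path and os.path.isfile(cache_path):
        return torch.load(cache_path)   # later runs: just deserialize the binary file
    data = build_fn()                   # first run: do the expensive preprocessing
    if cache_path:
        torch.save(data, cache_path)    # torch.save can serialize any picklable object
    return data
```

The `get_and_tokenize_dataset` helper from the tutorial below follows the same pattern for the text datasets: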
```python
import logging
import os

import torch
from tqdm import tqdm
from pytorch_pretrained_bert import cached_path

# DATASETS_URL, DATASETS_LABELS_URL and DATASETS_LABELS_CONVERSION are module-level
# dicts defined elsewhere in the tutorial's utils.py (dataset name -> download URLs
# and label-to-integer maps).
logger = logging.getLogger(__name__)


def get_and_tokenize_dataset(tokenizer, dataset_dir='wikitext-103', dataset_cache=None, with_labels=False):
    """ Retrieve, tokenize, encode and cache a dataset with optional labels """
    if dataset_cache and os.path.isfile(dataset_cache):
        logger.info("Load encoded dataset from cache at %s", dataset_cache)
        encoded_dataset = torch.load(dataset_cache)
    else:
        # If the dataset is in our list of DATASETS_URL, use these urls, otherwise look for 'train.txt' and 'valid.txt' files
        if dataset_dir in DATASETS_URL:
            dataset_map = DATASETS_URL[dataset_dir]
        else:
            dataset_map = {'train': os.path.join(dataset_dir, 'train.txt'),
                           'valid': os.path.join(dataset_dir, 'valid.txt')}
        logger.info("Get dataset from %s", dataset_dir)

        # Download and read the dataset and replace a few tokens for compatibility with the Bert tokenizer we are using
        dataset = {}
        for split_name in dataset_map.keys():
            dataset_file = cached_path(dataset_map[split_name])
            with open(dataset_file, "r", encoding="utf-8") as f:
                all_lines = f.readlines()
            dataset[split_name] = [
                line.strip(' ').replace('<unk>', '[UNK]').replace('\n', '[SEP]' if not with_labels else '')
                for line in tqdm(all_lines)]

        # If we have labels, download them and convert them to integers
        labels = {}
        if with_labels:
            label_conversion_map = DATASETS_LABELS_CONVERSION[dataset_dir]
            for split_name in DATASETS_LABELS_URL[dataset_dir]:
                dataset_file = cached_path(DATASETS_LABELS_URL[dataset_dir][split_name])
                with open(dataset_file, "r", encoding="utf-8") as f:
                    all_lines = f.readlines()
                labels[split_name] = [label_conversion_map[line.strip()] for line in tqdm(all_lines)]

        # Tokenize and encode the dataset
        logger.info("Tokenize and encode the dataset")
        logging.getLogger("pytorch_pretrained_bert.tokenization").setLevel(logging.ERROR)  # No warning on sample size

        def encode(obj):
            if isinstance(obj, str):
                return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
            if isinstance(obj, dict):
                return dict((n, encode(o)) for n, o in obj.items())
            return list(encode(o) for o in tqdm(obj))

        encoded_dataset = encode(dataset)

        # Add labels if needed; for language modeling, flatten each split into one list
        # and keep the word count so word-level perplexity can be computed later
        for split_name in ['train', 'valid']:
            if with_labels:
                encoded_dataset[split_name + '_labels'] = labels[split_name]
            else:
                encoded_dataset[split_name] = [ind for line in encoded_dataset[split_name] for ind in line]
                encoded_dataset[split_name + '_num_words'] = sum(len(line.split(' ')) for line in dataset[split_name])

        # Save to cache
        if dataset_cache:
            logger.info("Save encoded dataset to cache at %s", dataset_cache)
            torch.save(encoded_dataset, dataset_cache)
    return encoded_dataset
```
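
A sketch of how the helper might be called, assuming the tutorial's BERT tokenizer setup (the `bert-base-uncased` model name and the cache file name are illustrative choices, not prescribed by the function):

```python
from pytorch_pretrained_bert import BertTokenizer

# Lowercased BERT tokenizer; any tokenizer with tokenize() and convert_tokens_to_ids() works.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# First call downloads/reads and encodes wikitext-103, then writes the cache file;
# subsequent calls with the same cache path simply torch.load it.
dataset = get_and_tokenize_dataset(tokenizer,
                                   dataset_dir='wikitext-103',
                                   dataset_cache='wikitext-103_bert-base-uncased_cache.bin')

# Without labels, each split is a flat list of token ids plus a word count for word-level ppl.
print(len(dataset['train']), dataset['train_num_words'])
```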