- 先看reader.py,主要功能是读取数据,以及序列数据的生成
def _read_words(filename):
with tf.gfile.GFile(filename, "r") as f:
if sys.version_info[0] >= 3:
return f.read().replace("\n", "<eos>").split()
else:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
tf.gfile.GFile 主要是用于HDFS等文件系统中的文件操作。详见
def _build_vocab(filename):
data = _read_words(filename)
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))
return word_to_id
一个Counter是dict子类,用于计数可哈希的对象。这是一个无序的容器,元素被作为字典的key存储,它们的计数作为字典的value存储。详见。其中counter.items()是返回一个元素计数列表
c.items() # convert to a list of (elem, cnt) pairs
zip()函数使用*list/tuple的方式表示时,是将list/tuple分开,作为位置参数传递给对应函数(前提是对应函数支持不定个数的位置参数),函数效果可见。
而且此处sorted()函数就是按照计数降序排序,实现的时候直接将counter.items()的元素tuple前后换了个位置。
def ptb_producer(raw_data, batch_size, num_steps, name=None):
with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)
data_len = tf.size(raw_data)
batch_len = data_len // batch_size
data = tf.reshape(raw_data[0 : batch_size * batch_len],
[batch_size, batch_len])
epoch_size = (batch_len - 1) // num_steps
assertion = tf.assert_positive(
epoch_size,
message="epoch_size == 0, decrease batch_size or num_steps")
with tf.control_dependencies([assertion]):
epoch_size = tf.identity(epoch_size, name="epoch_size")
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
x = tf.strided_slice(data, [0, i * num_steps],
[batch_size, (i + 1) * num_steps])
x.set_shape([batch_size, num_steps])
y = tf.strided_slice(data, [0, i * num_steps + 1],
[batch_size, (i + 1) * num_steps + 1])
y.set_shape([batch_size, num_steps])
return x, y
根据batch_size将原数据reshape(),划分了成了一个矩阵batch_size * batch_len维。(batch_len就是number of batches)后面代码解析见
- 有问题再接着写吧:)