本次验证的目的在于:使用tf.feature_column.embedding_column函数进行训练得到的vector,在预测时直接使用tf.feature_column.embedding_column(k,dimension=8,ckpt_to_load_from=',tensor_name_in_ckpt=embedding_weights")来调用时,能否将样本对应到相应的hash bucket中。
首先,将category column进行tf.feature_column.categorical_column_with_hash_bucket处理,最后得到的hash_bucket即为分hash bucket后,每个样本对应的hash桶id。
然后,在预测时,直接调用tf.feature_column.embedding_column(k,dimension=8,ckpt_to_load_from=',tensor_name_in_ckpt=embedding_weights"),对样本进行处理。这里取id=45,发现调用后id为45的样本为第0,41,433个,取sess.run()后的数据查看,发现进过训练得到的embedding向量完全一致!
经过这一系列的验证,发现直接tf.feature_column.embedding_column很方便,且能保证训练和预测一致。