kaldi官方给出的解码命令是online2-wav-nnet3-latgen-faster
,它的源码我在“基于kaldi的iOS语音识别(本地)+05+解码”已经贴出来了,下面就来详细讲解它解码的过程,后面讲解实时流解码的时候再说一些改进和修改的地方。
首先我们来看它是怎么加载模型的:
TransitionModel trans_model;
nnet3::AmNnetSimple am_nnet;
{
bool binary;
Input ki(nnet3_rxfilename, &binary);
trans_model.Read(ki.Stream(), binary);
am_nnet.Read(ki.Stream(), binary);
SetBatchnormTestMode(true, &(am_nnet.GetNnet()));
SetDropoutTestMode(true, &(am_nnet.GetNnet()));
nnet3::CollapseModel(nnet3::CollapseModelConfig(), &(am_nnet.GetNnet()));
}
nnet3_rxfilename
就是我们给的final.mdl
模型文件
Nnet3中的AmNnetSimple
类是一个标准的声学模型类,该类通过调用Nnet类进行神经网络操作。
kaldi中的HMM模型,实际就是一个TransitionModel
对象。这个对象描述了音素的HMM拓扑结构,并保存了pdf-id和transition-id相关的信息,并且可以进行各种变量的转换。
这里先不对AmNnetSimple
和TransitionModel
类展开讲解,知道它是干嘛的就可以了。
nnet3::DecodableNnetSimpleLoopedInfo decodable_info(decodable_opts, &am_nnet);
此对象包含所有可解码对象使用的预先计算的内容。
fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
fst::Fst<fst::StdArc> *decode_fst = ReadFstKaldiGeneric(fst_rxfilename);
fst::SymbolTable *word_syms = NULL;
if (word_syms_rxfilename != "")
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename)))
KALDI_ERR << "Could not read symbol table from file " << word_syms_rxfilename;
fst_rxfilename
对应HCLG.fst
文件。word_syms_rxfilename
对应words.txt
文件。
SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
CompactLatticeWriter clat_writer(clat_wspecifier);
这里分别创建spk2utt_reader
(说话人+音频),wav_reader
(待识别音频),clat_writer
(lattice写入)对象。
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
std::string spk = spk2utt_reader.Key();
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
OnlineIvectorExtractorAdaptationState adaptation_state(
feature_info.ivector_extractor_info);
for (size_t i = 0; i < uttlist.size(); i++) {
std::string utt = uttlist[i];
if (!wav_reader.HasKey(utt)) {
KALDI_WARN << "Did not find audio for utterance " << utt;
num_err++;
continue;
}
}
...
}
该for循环,关于说话人音频的内容,我们先不关心,我们继续其他的部分。
const WaveData &wave_data = wav_reader.Value(utt);
// get the data for channel zero (if the signal is not mono, we only
// take the first channel).
SubVector<BaseFloat> data(wave_data.Data(), 0);
根据音频的id读取音频数据。
OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
feature_pipeline.SetAdaptationState(adaptation_state);
OnlineSilenceWeighting silence_weighting(trans_model,
feature_info.silence_weighting_config,
decodable_opts.frame_subsampling_factor);
SingleUtteranceNnet3Decoder decoder(decoder_opts,
trans_model,
decodable_info,
*decode_fst, &feature_pipeline);
OnlineNnet2FeaturePipeline
负责将神经网络的特征处理管道的各个部分组合在一起。
OnlineSilenceWeighting
负责跟踪来自解码器的最佳路径回溯(有效的)并基于帧的分类计算数据的加权在静音(或非静音)。
SingleUtteranceNnet3Decoder
使用神经网络的在线配置来解码单个音频。
BaseFloat samp_freq = wave_data.SampFreq();
int32 chunk_length;
if (chunk_length_secs > 0) {
chunk_length = int32(samp_freq * chunk_length_secs);
if (chunk_length == 0) chunk_length = 1;
} else {
chunk_length = std::numeric_limits<int32>::max();
}
int32 samp_offset = 0;
std::vector<std::pair<int32, BaseFloat> > delta_weights;
while (samp_offset < data.Dim()) {
int32 samp_remaining = data.Dim() - samp_offset;
int32 num_samp = chunk_length < samp_remaining ? chunk_length
: samp_remaining;
SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
feature_pipeline.AcceptWaveform(samp_freq, wave_part);
samp_offset += num_samp;
decoding_timer.WaitUntil(samp_offset / samp_freq);
if (samp_offset == data.Dim()) {
// no more input. flush out last frames
feature_pipeline.InputFinished();
}
if (silence_weighting.Active() &&
feature_pipeline.IvectorFeature() != NULL) {
silence_weighting.ComputeCurrentTraceback(decoder.Decoder());
silence_weighting.GetDeltaWeights(feature_pipeline.NumFramesReady(),
&delta_weights);
feature_pipeline.IvectorFeature()->UpdateFrameWeights(delta_weights);
}
decoder.AdvanceDecoding();
if (do_endpointing && decoder.EndpointDetected(endpoint_opts)) {
break;
}
}
将音频一段一段给feature_pipeline
同时开启解码decoder.AdvanceDecoding();
decoder.FinalizeDecoding();
完成解码