// Copyright 2005-2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the 'License');
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an 'AS IS' BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Data structures for storing and looking up the actual feature weights.
//
// NOTE(review): this copy of the header appears to have lost every `<...>`
// span. The `#include` directives below name no header, `template` has no
// parameter list, and `std::vector` / `std::unique_ptr` / `std::pair` have no
// template arguments. Parts of the `LinearFstData` declaration are also
// missing (see the notes inside the class). It will not compile as-is;
// restore it from the upstream source before building.

#ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_
#define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include

namespace fst {

// Forward declarations.
template
class FeatureGroup;
template
class LinearFstDataBuilder;

// Immutable data storage of the feature weights in a linear
// model. Produces state tuples that represent internal states of a
// LinearTaggerFst. Objects of this class can only be constructed via
// either `LinearFstDataBuilder::Dump()` or `LinearFstData::Read()`
// and are usually used as refcount'd objects shared across multiple
// `LinearTaggerFst` copies.
//
// TODO(wuke): more efficient trie implementation
template
class LinearFstData {
 public:
  friend class LinearFstDataBuilder;  // For builder access
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  // Sentence boundary labels. Both of them are negative labels other
  // than `kNoLabel`.
  static constexpr Label kStartOfSentence = -3;
  static constexpr Label kEndOfSentence = -2;

  // Constructs empty data; for non-trivial ways of construction see
  // `Read()` and `LinearFstDataBuilder`.
  LinearFstData()
      : max_future_size_(0), max_input_label_(1), input_attribs_(1) {}

  // Appends the state tuple of the start state to `output`, where
  // each tuple holds the node ids of a trie for each feature group.
  //
  // NOTE(review): text is missing on the next line. The parameter list of
  // `EncodeStartState` is never closed and runs directly into `*Read(`; the
  // declarations that stood between them appear to have been lost.
  void EncodeStartState(std::vector *Read(std::istream &strm);
  std::ostream &Write(std::ostream &strm) const;

 private:
  // Offsets in `output_pool_`. The candidate output labels of an input label
  // occupy `output_length` entries starting at `output_begin`. A zero
  // `output_length` makes `PossibleOutputLabels()` fall back to the whole
  // `output_set_`.
  struct InputAttribute {
    size_t output_begin, output_length;

    std::istream &Read(std::istream &strm);
    std::ostream &Write(std::ostream &strm) const;
  };

  // Mapping from input label to per-group feature label
  class GroupFeatureMap;

  // Translates the input label into input feature label of group
  // `group`; returns `kNoLabel` when there is no feature for that
  // group. Sentence-boundary labels are returned unchanged.
  Label FindFeature(size_t group, Label word) const;

  size_t max_future_size_;
  Label max_input_label_;
  std::vector>> groups_;
  std::vector input_attribs_;
  // NOTE(review): text is missing below. `output_pool_`, `output_set_` and
  // `group_feat_map_` (used by `Read()` / `Write()` further down) are never
  // declared, no closing `};` for the class is visible, and the parameter
  // list of `TakeTransition()` runs straight into the definition of
  // `GroupTransition()`.
  std::vector::TakeTransition(
      Iterator buffer_end, Iterator trie_state_begin, Iterator trie_state_end,
      Label ilabel, Label olabel, std::vector::GroupTransition(
          int group_id, int trie_state, Label ilabel, Label olabel,
          Weight *weight) const {
  // Maps `ilabel` to group `group_id`'s feature label, then walks that
  // group's trie from `trie_state`; `FeatureGroup::Walk()` multiplies the
  // transition weight onto `*weight`.
  Label group_ilabel = FindFeature(group_id, ilabel);
  return groups_[group_id]->Walk(trie_state, group_ilabel, olabel, weight);
}

// Final weight of a state tuple: the `Times` product, over all feature
// groups, of `GroupFinalWeight()` for that group's trie state. The range
// [`trie_state_begin`, `trie_state_end`) must hold exactly one trie state per
// feature group (DCHECKed).
template
template
inline typename A::Weight LinearFstData::FinalWeight(
    Iterator trie_state_begin, Iterator trie_state_end) const {
  DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size());
  size_t group_id = 0;
  Weight accum = Weight::One();
  for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id)
    accum = Times(accum, GroupFinalWeight(group_id, *it));
  return accum;
}

// Returns the [begin, end) range of output labels that may be paired with
// input label `word`. This is the `output_pool_` slice recorded in
// `input_attribs_[word]`, or all of `output_set_` when that slice is empty.
// NOTE(review): `word` indexes `input_attribs_` without a bounds check, so
// callers must ensure it is a valid index.
template
inline std::pair::const_iterator,
                 typename std::vector::const_iterator>
LinearFstData::PossibleOutputLabels(Label word) const {
  const InputAttribute &attrib = input_attribs_[word];
  if (attrib.output_length == 0)
    return std::make_pair(output_set_.begin(), output_set_.end());
  else
    return std::make_pair(
        output_pool_.begin() + attrib.output_begin,
        output_pool_.begin() + attrib.output_begin + attrib.output_length);
}

// Deserializes data previously produced by `Write()`, consuming the fields in
// the same order. Returns a newly allocated object that the caller owns, or
// nullptr if the stream is in a failed state after all fields are read.
// NOTE(review): `num_groups` comes straight from the stream and is passed to
// `resize()` unchecked, so a corrupt input can request a huge allocation.
template
inline LinearFstData *LinearFstData::Read(std::istream &strm) {
  std::unique_ptr> data(new LinearFstData());
  ReadType(strm, &(data->max_future_size_));
  ReadType(strm, &(data->max_input_label_));
  // Feature groups: count, then each group in order.
  size_t num_groups = 0;
  ReadType(strm, &num_groups);
  data->groups_.resize(num_groups);
  for (size_t i = 0; i < num_groups; ++i)
    data->groups_[i].reset(FeatureGroup::Read(strm));
  // Other data
  ReadType(strm, &(data->input_attribs_));
  ReadType(strm, &(data->output_pool_));
  ReadType(strm, &(data->output_set_));
  ReadType(strm, &(data->group_feat_map_));
  if (strm) {
    return data.release();
  } else {
    return nullptr;
  }
}

// Serializes the data in the order `Read()` consumes it: the two scalar
// fields, the feature groups (count first, then each group), then the
// remaining members. Returns `strm`; callers should check its state.
template
inline std::ostream &LinearFstData::Write(std::ostream &strm) const {
  WriteType(strm, max_future_size_);
  WriteType(strm, max_input_label_);
  // Feature groups
  WriteType(strm, groups_.size());
  for (size_t i = 0; i < groups_.size(); ++i) {
    groups_[i]->Write(strm);
  }
  // Other data
  WriteType(strm, input_attribs_);
  WriteType(strm, output_pool_);
  WriteType(strm, output_set_);
  WriteType(strm, group_feat_map_);
  return strm;
}

// Sentence-boundary labels are returned as-is. Every other label must be
// positive (DCHECKed) and is looked up in `group_feat_map_`.
template
typename A::Label LinearFstData::FindFeature(size_t group, Label word) const {
  DCHECK(word > 0 || word == kStartOfSentence || word == kEndOfSentence);
  if (word == kStartOfSentence || word == kEndOfSentence)
    return word;
  else
    return group_feat_map_.Find(group, word);
}

// Reads `output_begin` followed by `output_length`.
template
inline std::istream &LinearFstData::InputAttribute::Read(
    std::istream &strm) {
  ReadType(strm, &output_begin);
  ReadType(strm, &output_length);
  return strm;
}

// Writes `output_begin` followed by `output_length` (the order `Read()`
// expects).
template
inline std::ostream &LinearFstData::InputAttribute::Write(
    std::ostream &strm) const {
  WriteType(strm, output_begin);
  WriteType(strm, output_length);
  return strm;
}

// Forward declaration
template
class FeatureGroupBuilder;

// An immutable grouping of features with similar context shape. Like
// `LinearFstData`, this can only be constructed via `Read()` or
// via its builder.
//
// Internally it uses a trie to store all feature n-grams and their
// weights. The label of a trie edge is a pair (feat, olabel) of
// labels. They can be either positive (ordinary label), `kNoLabel`,
// `kStartOfSentence`, or `kEndOfSentence`. `kNoLabel` usually means
// matching anything, with one exception: from the root of the trie,
// there is a special (kNoLabel, kNoLabel) that leads to the implicit
// start-of-sentence state. This edge is never actually matched
// (`FindFirstMatch()` ensures this).
//
// NOTE(review): in this copy every `<...>` span appears to have been
// stripped (bare `template`, `std::unique_ptr>`, `std::vector` without an
// element type, `static_cast(`). This section will not compile as-is;
// restore the template arguments from the upstream source.
template
class FeatureGroup {
 public:
  friend class FeatureGroupBuilder;  // for builder access
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;

  // Returns the id of this group's start state.
  int Start() const { return start_; }

  // Finds destination node from `cur` by consuming `ilabel` and
  // `olabel`. The transition weight is multiplied onto `weight`.
  int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const;

  // Returns the final weight of the current trie state. Only valid if
  // the state is already known to be part of a final state (see
  // `LinearFstData<>::CanBeFinal()`).
  Weight FinalWeight(int trie_state) const {
    return trie_[trie_state].final_weight;
  }

  // Reads a group in the format produced by `Write()`: delay, start state,
  // trie, then the next-state table. Returns a newly allocated group owned by
  // the caller, or nullptr if the stream failed.
  static FeatureGroup *Read(std::istream &strm) {
    size_t delay;
    ReadType(strm, &delay);
    int start;
    ReadType(strm, &start);
    Trie trie;
    ReadType(strm, &trie);
    std::unique_ptr> ret(new FeatureGroup(delay, start));
    ret->trie_.swap(trie);
    ReadType(strm, &ret->next_state_);
    if (strm) {
      return ret.release();
    } else {
      return nullptr;
    }
  }

  // Writes the group in the order `Read()` consumes it. Returns `strm`.
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, delay_);
    WriteType(strm, start_);
    WriteType(strm, trie_);
    WriteType(strm, next_state_);
    return strm;
  }

  // Returns the delay this group was constructed with.
  size_t Delay() const { return delay_; }

  // Human-readable summary: number of trie nodes and number of states.
  std::string Stats() const;

 private:
  // Label along the arcs on the trie. `kNoLabel` means anything
  // (non-negative label) can match; both sides holding `kNoLabel`
  // is not allowed; otherwise the label is > 0 (enforced by
  // `LinearFstDataBuilder::AddWeight()`).
  struct InputOutputLabel;
  // Hash functor for `InputOutputLabel`.
  struct InputOutputLabelHash;

  // Data to be stored on the trie. `back_link` is the back-off node followed
  // by `FindFirstMatch()`. `weight` is multiplied in by `Walk()` when the node
  // is reached. `final_weight` is what `FinalWeight()` returns.
  struct WeightBackLink {
    int back_link;
    Weight weight, final_weight;

    // No back-off link; both weights start at `Weight::One()`.
    WeightBackLink()
        : back_link(kNoTrieNodeId),
          weight(Weight::One()),
          final_weight(Weight::One()) {}

    // Reads `back_link`, `weight`, `final_weight` in that order.
    std::istream &Read(std::istream &strm) {
      ReadType(strm, &back_link);
      ReadType(strm, &weight);
      ReadType(strm, &final_weight);
      return strm;
    }

    // Writes the fields in the order `Read()` expects.
    std::ostream &Write(std::ostream &strm) const {
      WriteType(strm, back_link);
      WriteType(strm, weight);
      WriteType(strm, final_weight);
      return strm;
    }
  };

  typedef FlatTrieTopology Topology;
  typedef MutableTrie Trie;

  // Only reachable through `Read()` and the builder (private constructor).
  explicit FeatureGroup(size_t delay, int start)
      : delay_(delay), start_(start) {}

  // Finds the first node with an arc with `label` following the
  // back-off chain of `parent`. Returns the node index or
  // `kNoTrieNodeId` when not found.
  int FindFirstMatch(InputOutputLabel label, int parent) const;

  size_t delay_;
  int start_;
  Trie trie_;
  // Where to go after hitting this state. When we reach a state with
  // no child and with no additional final weight (i.e. its final
  // weight is the same as its back-off), we can immediately go to its
  // back-off state.
  std::vector next_state_;

  FeatureGroup(const FeatureGroup &) = delete;
  FeatureGroup &operator=(const FeatureGroup &) = delete;
};

// A (feature, output) label pair used as a trie edge label. Both halves
// default to `kNoLabel`.
template
struct FeatureGroup::InputOutputLabel {
  Label input, output;

  explicit InputOutputLabel(Label i = kNoLabel, Label o = kNoLabel)
      : input(i), output(o) {}

  bool operator==(InputOutputLabel that) const {
    return input == that.input && output == that.output;
  }

  // Reads `input` then `output`.
  std::istream &Read(std::istream &strm) {
    ReadType(strm, &input);
    ReadType(strm, &output);
    return strm;
  }

  // Writes `input` then `output`.
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, input);
    WriteType(strm, output);
    return strm;
  }
};

// Hashes an `InputOutputLabel` by combining the two labels linearly.
//
// NOTE(review): `label.input * 7853 + label.output` is evaluated in `Label`
// arithmetic before the cast. `Label` is signed (it must hold the negative
// sentence-boundary labels), so a large input label can overflow. With a
// 32-bit `Label`, for example, any input above roughly 273,000 overflows,
// and signed overflow is undefined behavior. Converting each operand first,
// e.g. `static_cast<size_t>(label.input) * 7853 +
// static_cast<size_t>(label.output)`, avoids the UB. It also yields the same
// hash whenever the current expression does not overflow.
template
struct FeatureGroup::InputOutputLabelHash {
  size_t operator()(InputOutputLabel label) const {
    return static_cast(label.input * 7853 + label.output);
  }
};

// Walks one step from trie state `cur` on (`ilabel`, `olabel`) and returns
// the next trie state. Matching tries, in order:
//   1. the exact pair;
//   2. (`ilabel`, any);
//   3. (any, `olabel`).
// Each attempt follows the back-off chain of `cur`. If all three fail, the
// root is used. The matched node's weight is multiplied onto `*weight`.
template
int FeatureGroup::Walk(int cur, Label ilabel, Label olabel,
                       Weight *weight) const {
  // Note: users of this method need to ensure `ilabel` and `olabel`
  // are valid (e.g. see DCHECKs in
  // `LinearFstData<>::TakeTransition()` and
  // `LinearFstData<>::FindFeature()`).
  int next;
  if (ilabel == LinearFstData::kStartOfSentence) {
    // An observed start-of-sentence only occurs in the beginning of
    // the input, when this feature group is delayed (i.e. there is
    // another feature group with a larger future size). The actual
    // input hasn't arrived so stay at the start state; `*weight` is left
    // unchanged.
    DCHECK_EQ(cur, start_);
    next = start_;
  } else {
    // First, try exact match
    next = FindFirstMatch(InputOutputLabel(ilabel, olabel), cur);
    // Then try with don't cares
    if (next == kNoTrieNodeId)
      next = FindFirstMatch(InputOutputLabel(ilabel, kNoLabel), cur);
    if (next == kNoTrieNodeId)
      next = FindFirstMatch(InputOutputLabel(kNoLabel, olabel), cur);
    // All failed, go to empty context
    if (next == kNoTrieNodeId) next = trie_.Root();
    *weight = Times(*weight, trie_[next].weight);
    // Jump to the node's precomputed successor (see `next_state_`).
    next = next_state_[next];
  }
  return next;
}

// Starting at `parent` and following `back_link`s, returns the destination
// of the first edge labeled `label` found along the chain, or
// `kNoTrieNodeId` if there is none. The (kNoLabel, kNoLabel) pair is rejected
// up front so the special root edge described in the class comment can never
// be matched.
template
inline int FeatureGroup::FindFirstMatch(InputOutputLabel label,
                                        int parent) const {
  if (label.input == kNoLabel && label.output == kNoLabel)
    return kNoTrieNodeId;  // very important; see class doc.
  for (; parent != kNoTrieNodeId; parent = trie_[parent].back_link) {
    int next = trie_.Find(parent, label);
    if (next != kNoTrieNodeId) return next;
  }
  return kNoTrieNodeId;
}

// Formats "N node(s); M state(s)". M is 2 plus the number of node ids i >= 2
// with `next_state_[i] == i`, i.e. nodes that `next_state_` does not redirect.
// NOTE(review): ids 0 and 1 are always counted; presumably they are the root
// and the start state -- confirm. The loop also compares `int` with
// `size_t` (sign-compare warning).
template
inline std::string FeatureGroup::Stats() const {
  std::ostringstream strm;
  int num_states = 2;
  for (int i = 2; i < next_state_.size(); ++i)
    num_states += i == next_state_[i];
  strm << trie_.NumNodes() << " node(s); " << num_states << " state(s)";
  return strm.str();
}

// Dense table mapping (group id, input label) to that group's feature label.
// Entries with no feature hold `kNoLabel`.
template
class LinearFstData::GroupFeatureMap {
 public:
  // NOTE(review): the defaulted constructor leaves `num_groups_`
  // uninitialized (it has no default member initializer); call `Init()` or
  // `Read()` before using the map.
  GroupFeatureMap() = default;

  // Sizes the table for `num_groups` groups x `num_words` input labels and
  // resets every entry to `kNoLabel`.
  void Init(size_t num_groups, size_t num_words) {
    num_groups_ = num_groups;
    pool_.clear();
    pool_.resize(num_groups * num_words, kNoLabel);
  }

  // Returns the feature of `ilabel` in group `group_id` (`kNoLabel` if none).
  // No bounds check is performed.
  Label Find(size_t group_id, Label ilabel) const {
    return pool_[IndexOf(group_id, ilabel)];
  }

  // Records `feat` as the feature of `ilabel` in group `group_id`. If a
  // different feature is already recorded, reports the conflict via
  // `FSTERROR()` and returns false. Setting the same feature again succeeds.
  bool Set(size_t group_id, Label ilabel, Label feat) {
    size_t i = IndexOf(group_id, ilabel);
    if (pool_[i] != kNoLabel && pool_[i] != feat) {
      FSTERROR() << "Feature group " << group_id
                 << " already has feature for word " << ilabel;
      return false;
    }
    pool_[i] = feat;
    return true;
  }

  // Reads `num_groups_` then the table.
  std::istream &Read(std::istream &strm) {
    ReadType(strm, &num_groups_);
    ReadType(strm, &pool_);
    return strm;
  }

  // Writes `num_groups_` then the table (the order `Read()` expects).
  std::ostream &Write(std::ostream &strm) const {
    WriteType(strm, num_groups_);
    WriteType(strm, pool_);
    return strm;
  }

 private:
  // One row of `num_groups_` entries per input label.
  size_t IndexOf(size_t group_id, Label ilabel) const {
    return ilabel * num_groups_ + group_id;
  }

  size_t num_groups_;
  // `pool_[ilabel * num_groups_ + group_id]` is the feature active
  // for group `group_id` with input `ilabel`
  std::vector