// Copyright 2005-2024 Google LLC // // Licensed under the Apache License, Version 2.0 (the 'License'); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an 'AS IS' BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // // See www.openfst.org for extensive documentation on this weighted // finite-state transducer library. #ifndef FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ #define FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include namespace fst { // Applies a FST model as a discriminative model to weighted input // `ifst`. `A` is an arc type with tropical weight of all the // input/output FSTs. // // In general, consider `ifst` an unnormalized probability // distribution between its input X and output Y, P(X, Y); and `lfst` // a group of unnormalized probability distributions of all its output // Z for every input Y, Q(Z|Y). `normalize` controls whether Q is // normalized for every Y before chaining with P(X, Y). I.e., for a // path (X, Y, Z) in `ofst` (where Y is hidden), // // - When `normalize` is true, its weight is P(X, Y) Q(Z|Y) / sum_z Q(z|Y); // - When `normalize` is false, its weight is P(X, Y) Q(Z|Y). template void LogLinearApply(const Fst &ifst, const Fst &lfst, MutableFst *ofst, bool normalize = true) { LogLinearApply(ifst, lfst, ofst, normalize); } // This version gives finer control over the arc type (`B`) to be used // in normalization. `B` is an arc type with log weight (e.g. `LogArc` // or `Log64Arc`). template void LogLinearApply(const Fst &ifst, const Fst &lfst, MutableFst *ofst, bool normalize = true) { if (normalize) { VectorFst unnormalized_ofst, rescored_ifsa; Compose(ifst, lfst, &unnormalized_ofst); { VectorFst tropical_ifsa(unnormalized_ofst); Project(&tropical_ifsa, ProjectType::INPUT); { VectorFst minimal_log_ifsa; { VectorFst log_ifsa; ArcMap(tropical_ifsa, &log_ifsa, WeightConvertMapper()); RmEpsilon(&log_ifsa); Determinize(log_ifsa, &minimal_log_ifsa); } Minimize(&minimal_log_ifsa); ArcMap(&minimal_log_ifsa, InvertWeightMapper()); ArcMap(minimal_log_ifsa, &tropical_ifsa, WeightConvertMapper()); } ArcSort(&tropical_ifsa, OLabelCompare()); Compose(tropical_ifsa, ifst, &rescored_ifsa); } ArcSort(&rescored_ifsa, OLabelCompare()); Compose(rescored_ifsa, unnormalized_ofst, ofst); } else { Compose(ifst, lfst, ofst); } } } // namespace fst #endif // FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_