Repository: incubator-joshua Updated Branches: refs/heads/master e7ead8fb3 -> 4c0b55337
Viterbi information is now extracted from the hypergraph using a more principled traversal functionality (WalkerFunction). Also updated the unit tests. Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/244e6936 Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/244e6936 Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/244e6936 Branch: refs/heads/master Commit: 244e6936d8e3e7b30ebbe49ff7a9a2bd0c0c9994 Parents: 9501535 Author: Felix Hieber <[email protected]> Authored: Mon Aug 24 08:29:17 2015 +0200 Committer: Kellen Sunderland <[email protected]> Committed: Thu Mar 31 10:44:42 2016 +0200 ---------------------------------------------------------------------- .../joshua/decoder/StructuredTranslation.java | 143 ++++++++++++++ .../ViterbiFeatureVectorWalkerFunction.java | 44 +++++ .../ViterbiOutputStringWalkerFunction.java | 96 ++++++++++ src/joshua/decoder/JoshuaConfiguration.java | 5 +- .../system/StructuredTranslationTest.java | 184 +++++++++++++++++++ 5 files changed, 470 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/244e6936/joshua-6/src/joshua/decoder/StructuredTranslation.java ---------------------------------------------------------------------- diff --git a/joshua-6/src/joshua/decoder/StructuredTranslation.java b/joshua-6/src/joshua/decoder/StructuredTranslation.java new file mode 100644 index 0000000..1939ea0 --- /dev/null +++ b/joshua-6/src/joshua/decoder/StructuredTranslation.java @@ -0,0 +1,143 @@ +package joshua.decoder; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static joshua.decoder.hypergraph.ViterbiExtractor.walk; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import 
joshua.decoder.ff.FeatureFunction; +import joshua.decoder.hypergraph.HyperGraph; +import joshua.decoder.hypergraph.ViterbiFeatureVectorWalkerFunction; +import joshua.decoder.hypergraph.ViterbiOutputStringWalkerFunction; +import joshua.decoder.hypergraph.WalkerFunction; +import joshua.decoder.hypergraph.WordAlignmentExtractor; +import joshua.decoder.segment_file.Sentence; + +/** + * structuredTranslation provides a more structured access to translation + * results than the Translation class. + * Members of instances of this class can be used upstream. + * <br/> + * TODO: + * Enable K-Best extraction. + * + * @author fhieber + */ +public class StructuredTranslation { + + private final Sentence sourceSentence; + private final List<FeatureFunction> featureFunctions; + + private final String translationString; + private final List<String> translationTokens; + private final float translationScore; + private List<List<Integer>> translationWordAlignments; + private Map<String,Float> translationFeatures; + private final float extractionTime; + + public StructuredTranslation(final Sentence sourceSentence, + final HyperGraph hypergraph, + final List<FeatureFunction> featureFunctions) { + + final long startTime = System.currentTimeMillis(); + + this.sourceSentence = sourceSentence; + this.featureFunctions = featureFunctions; + this.translationString = extractViterbiString(hypergraph); + this.translationTokens = extractTranslationTokens(); + this.translationScore = extractTranslationScore(hypergraph); + this.translationFeatures = extractViterbiFeatures(hypergraph); + this.translationWordAlignments = extractViterbiWordAlignment(hypergraph); + this.extractionTime = (System.currentTimeMillis() - startTime) / 1000.0f; + } + + private Map<String,Float> extractViterbiFeatures(final HyperGraph hypergraph) { + if (hypergraph == null) { + return emptyMap(); + } else { + ViterbiFeatureVectorWalkerFunction viterbiFeatureVectorWalker = new 
ViterbiFeatureVectorWalkerFunction(featureFunctions, sourceSentence); + walk(hypergraph.goalNode, viterbiFeatureVectorWalker); + return new HashMap<String,Float>(viterbiFeatureVectorWalker.getFeaturesMap()); + } + } + + private List<List<Integer>> extractViterbiWordAlignment(final HyperGraph hypergraph) { + if (hypergraph == null) { + return emptyList(); + } else { + final WordAlignmentExtractor wordAlignmentWalker = new WordAlignmentExtractor(); + walk(hypergraph.goalNode, wordAlignmentWalker); + return wordAlignmentWalker.getFinalWordAlignments(); + } + } + + private float extractTranslationScore(final HyperGraph hypergraph) { + if (hypergraph == null) { + return 0; + } else { + return hypergraph.goalNode.getScore(); + } + } + + private String extractViterbiString(final HyperGraph hypergraph) { + if (hypergraph == null) { + return sourceSentence.source(); + } else { + final WalkerFunction viterbiOutputStringWalker = new ViterbiOutputStringWalkerFunction(); + walk(hypergraph.goalNode, viterbiOutputStringWalker); + return viterbiOutputStringWalker.toString(); + } + } + + private List<String> extractTranslationTokens() { + if (translationString.isEmpty()) { + return emptyList(); + } else { + return asList(translationString.split("\\s+")); + } + } + + // Getters to use upstream + + public Sentence getSourceSentence() { + return sourceSentence; + } + + public int getSentenceId() { + return sourceSentence.id(); + } + + public String getTranslationString() { + return translationString; + } + + public List<String> getTranslationTokens() { + return translationTokens; + } + + public float getTranslationScore() { + return translationScore; + } + + /** + * Returns a list of target to source alignments. + */ + public List<List<Integer>> getTranslationWordAlignments() { + return translationWordAlignments; + } + + public Map<String,Float> getTranslationFeatures() { + return translationFeatures; + } + + /** + * Time taken to build output information from the hypergraph. 
+ */ + public Float getExtractionTime() { + return extractionTime; + } +} http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/244e6936/joshua-6/src/joshua/decoder/hypergraph/ViterbiFeatureVectorWalkerFunction.java ---------------------------------------------------------------------- diff --git a/joshua-6/src/joshua/decoder/hypergraph/ViterbiFeatureVectorWalkerFunction.java b/joshua-6/src/joshua/decoder/hypergraph/ViterbiFeatureVectorWalkerFunction.java new file mode 100644 index 0000000..5af6c4d --- /dev/null +++ b/joshua-6/src/joshua/decoder/hypergraph/ViterbiFeatureVectorWalkerFunction.java @@ -0,0 +1,44 @@ +package joshua.decoder.hypergraph; + +import static joshua.decoder.chart_parser.ComputeNodeResult.computeTransitionFeatures; + +import java.util.List; +import java.util.Map; + +import joshua.decoder.ff.FeatureFunction; +import joshua.decoder.ff.FeatureVector; +import joshua.decoder.segment_file.Sentence; + +public class ViterbiFeatureVectorWalkerFunction implements WalkerFunction { + + private final FeatureVector features; + private final List<FeatureFunction> featureFunctions; + private final Sentence sourceSentence; + + public ViterbiFeatureVectorWalkerFunction( + final List<FeatureFunction> featureFunctions, + final Sentence sourceSentence) { + this.features = new FeatureVector(); + this.featureFunctions = featureFunctions; + this.sourceSentence = sourceSentence; + } + + /** + * Recompute feature values for each Viterbi edge and add to features. 
+ */ + @Override + public void apply(HGNode node) { + final FeatureVector edgeFeatures = computeTransitionFeatures( + featureFunctions, node.bestHyperedge, node.i, node.j, sourceSentence); + features.add(edgeFeatures); + } + + public FeatureVector getFeatures() { + return features; + } + + public Map<String,Float> getFeaturesMap() { + return features.getMap(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/244e6936/joshua-6/src/joshua/decoder/hypergraph/ViterbiOutputStringWalkerFunction.java ---------------------------------------------------------------------- diff --git a/joshua-6/src/joshua/decoder/hypergraph/ViterbiOutputStringWalkerFunction.java b/joshua-6/src/joshua/decoder/hypergraph/ViterbiOutputStringWalkerFunction.java new file mode 100644 index 0000000..0c84375 --- /dev/null +++ b/joshua-6/src/joshua/decoder/hypergraph/ViterbiOutputStringWalkerFunction.java @@ -0,0 +1,96 @@ +package joshua.decoder.hypergraph; + +import static java.lang.Integer.MAX_VALUE; +import static joshua.corpus.Vocabulary.getWords; +import static joshua.corpus.Vocabulary.nt; + +import java.util.Stack; + +import joshua.decoder.ff.tm.Rule; + +public class ViterbiOutputStringWalkerFunction implements WalkerFunction { + + private Stack<int[]> viterbiWords = new Stack<int[]>(); + + @Override + public void apply(HGNode node) { + final Rule rule = node.bestHyperedge.getRule(); + if (rule != null) { + merge(rule.getEnglish()); + } + } + + private boolean containsNonTerminals(final int[] ids) { + boolean hasNonTerminals = false; + for (int i = 0; i < ids.length; i++) { + if (nt(ids[i])) { + hasNonTerminals = true; + break; + } + } + return hasNonTerminals; + } + + /** + * Returns the index of the next non-terminal slot to fill. + * Since non-terminals in right hand sides of rules are indexed by + * their order on the source side, this function looks for the largest + * negative id in ids and returns its index. 
+ */ + private int getNextNonTerminalIndexToFill(final int[] ids) { + int nextIndex = 0; + int nextNonTerminal = -MAX_VALUE; + for (int i = 0; i < ids.length; i++) { + if (nt(ids[i]) && ids[i] > nextNonTerminal) { + nextIndex = i; + nextNonTerminal = ids[i]; + } + } + return nextIndex; + } + + private int[] substituteNonTerminal(final int[] parentWords, final int[] childWords) { + final int ntIndex = getNextNonTerminalIndexToFill(parentWords); + final int[] result = new int[parentWords.length + childWords.length - 1]; + int resultIndex = 0; + for (int i = 0; i < ntIndex; i++) { + result[resultIndex++] = parentWords[i]; + } + for (int i = 0; i < childWords.length; i++) { + result[resultIndex++] = childWords[i]; + } + for (int i = ntIndex + 1; i < parentWords.length; i++) { + result[resultIndex++] = parentWords[i]; + } + return result; + } + + private void merge(final int[] words) { + if (!containsNonTerminals(words) + && !viterbiWords.isEmpty() + && containsNonTerminals(viterbiWords.peek())) { + merge(substituteNonTerminal(viterbiWords.pop(), words)); + } else { + viterbiWords.add(words); + } + } + + @Override + public String toString() { + if (viterbiWords.isEmpty()) { + return ""; + } + + if (viterbiWords.size() != 1) { + throw new RuntimeException( + String.format( + "Stack of ViterbiOutputStringWalker should contain only a single (last) element, but was size %d", viterbiWords.size())); + } + + String result = getWords(viterbiWords.peek()); + // strip of sentence markers (<s>,</s>) + result = result.substring(result.indexOf(' ') + 1, result.lastIndexOf(' ')); + return result.trim(); + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/244e6936/src/joshua/decoder/JoshuaConfiguration.java ---------------------------------------------------------------------- diff --git a/src/joshua/decoder/JoshuaConfiguration.java b/src/joshua/decoder/JoshuaConfiguration.java index 266198c..2eb24c4 100644 --- 
a/src/joshua/decoder/JoshuaConfiguration.java +++ b/src/joshua/decoder/JoshuaConfiguration.java @@ -30,8 +30,9 @@ import joshua.util.io.LineReader; */ public class JoshuaConfiguration { - // whether to use structured output - public Boolean use_structured_output = false; + // whether to construct a StructuredTranslation object for each request instead of + // printing to stdout. Used when the Decoder is used from Java directly. + public Boolean construct_structured_output = false; // List of grammar files to read public ArrayList<String> tms = new ArrayList<String>(); http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/244e6936/tst/joshua/system/StructuredTranslationTest.java ---------------------------------------------------------------------- diff --git a/tst/joshua/system/StructuredTranslationTest.java b/tst/joshua/system/StructuredTranslationTest.java new file mode 100644 index 0000000..821ceea --- /dev/null +++ b/tst/joshua/system/StructuredTranslationTest.java @@ -0,0 +1,184 @@ +package joshua.system; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import joshua.corpus.Vocabulary; +import joshua.decoder.Decoder; +import joshua.decoder.JoshuaConfiguration; +import joshua.decoder.StructuredTranslation; +import joshua.decoder.Translation; +import joshua.decoder.segment_file.Sentence; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Integration test for the complete Joshua decoder using a toy grammar that translates + * a bunch of capital letters to lowercase letters. Rules in the test grammar + * drop and generate additional words and simulate reordering of rules, so that + * proper extraction of word alignments and other information from the decoder + * can be tested. 
+ * + * @author fhieber + */ +public class StructuredTranslationTest { + + private JoshuaConfiguration joshuaConfig = null; + private Decoder decoder = null; + private static final String INPUT = "A K B1 U Z1 Z2 B2 C"; + private static final String EXPECTED_TRANSLATION = "a b n1 u z c1 k1 k2 k3 n1 n2 n3 c2"; + private static final List<String> EXPECTED_TRANSLATED_TOKENS = asList(EXPECTED_TRANSLATION.split("\\s+")); + private static final String EXPECTED_WORD_ALIGNMENT_STRING = "0-0 2-1 6-1 3-3 4-4 5-4 7-5 1-6 1-7 1-8 7-12"; + private static final List<List<Integer>> EXPECTED_WORD_ALIGNMENT = asList( + asList(0), asList(2, 6), asList(), asList(3), + asList(4, 5), asList(7), asList(1), + asList(1), asList(1), asList(), asList(), + asList(), asList(7)); + private static final double EXPECTED_SCORE = -17.0; + private static final Map<String,Float> EXPECTED_FEATURES = new HashMap<>(); + static { + EXPECTED_FEATURES.put("tm_glue_0", 1.0f); + EXPECTED_FEATURES.put("tm_pt_0", -3.0f); + EXPECTED_FEATURES.put("tm_pt_1", -3.0f); + EXPECTED_FEATURES.put("tm_pt_2", -3.0f); + EXPECTED_FEATURES.put("tm_pt_3", -3.0f); + EXPECTED_FEATURES.put("tm_pt_4", -3.0f); + EXPECTED_FEATURES.put("tm_pt_5", -3.0f); + EXPECTED_FEATURES.put("OOV", 7.0f); + } + + @Before + public void setUp() throws Exception { + Vocabulary.clear(); + joshuaConfig = new JoshuaConfiguration(); + joshuaConfig.search_algorithm = "cky"; + joshuaConfig.mark_oovs = false; + joshuaConfig.pop_limit = 100; + joshuaConfig.use_unique_nbest = false; + joshuaConfig.include_align_index = false; + joshuaConfig.topN = 0; + joshuaConfig.tms.add("thrax -owner pt -maxspan 20 -path resources/wa_grammar"); + joshuaConfig.tms.add("thrax -owner glue -maxspan -1 -path resources/grammar.glue"); + joshuaConfig.goal_symbol = "[GOAL]"; + joshuaConfig.default_non_terminal = "[X]"; + joshuaConfig.features.add("feature_function = OOVPenalty"); + joshuaConfig.weights.add("tm_pt_0 1"); + joshuaConfig.weights.add("tm_pt_1 1"); + 
joshuaConfig.weights.add("tm_pt_2 1"); + joshuaConfig.weights.add("tm_pt_3 1"); + joshuaConfig.weights.add("tm_pt_4 1"); + joshuaConfig.weights.add("tm_pt_5 1"); + joshuaConfig.weights.add("tm_glue_0 1"); + joshuaConfig.weights.add("OOVPenalty 2"); + decoder = new Decoder(joshuaConfig, ""); // second argument (configFile + // is not even used by the + // constructor/initialize) + } + + @After + public void tearDown() throws Exception { + Vocabulary.clear(); + decoder.cleanUp(); + decoder = null; + } + + private Translation decode(String input) { + Sentence sentence = new Sentence(input, 0, joshuaConfig); + return decoder.decode(sentence); + } + + @Test + public void givenInput_whenRegularOutputFormat_thenExpectedOutput() { + // GIVEN + joshuaConfig.construct_structured_output = false; + joshuaConfig.outputFormat = "%s | %a "; + + // WHEN + final String translation = decode(INPUT).toString().trim(); + + // THEN + assertEquals(EXPECTED_TRANSLATION + " | " + EXPECTED_WORD_ALIGNMENT_STRING, translation); + } + + @Test + public void givenInput_whenSaarStructuredOutputFormat_thenExpectedOutput() { + // GIVEN + joshuaConfig.construct_structured_output = true; + + // WHEN + final StructuredTranslation translation = decode(INPUT).getStructuredTranslation(); + final String translationString = translation.getTranslationString(); + final List<String> translatedTokens = translation.getTranslationTokens(); + final float translationScore = translation.getTranslationScore(); + final List<List<Integer>> wordAlignment = translation.getTranslationWordAlignments(); + final Map<String,Float> translationFeatures = translation.getTranslationFeatures(); + + // THEN + assertEquals(EXPECTED_TRANSLATION, translationString); + assertEquals(EXPECTED_TRANSLATED_TOKENS, translatedTokens); + assertEquals(EXPECTED_SCORE, translationScore, 0.00001); + assertEquals(EXPECTED_WORD_ALIGNMENT, wordAlignment); + assertEquals(wordAlignment.size(), translatedTokens.size()); + 
assertEquals(EXPECTED_FEATURES.entrySet(), translationFeatures.entrySet()); + } + + @Test + public void givenEmptyInput_whenSaarStructuredOutputFormat_thenEmptyOutput() { + // GIVEN + joshuaConfig.construct_structured_output = true; + + // WHEN + final StructuredTranslation translation = decode("").getStructuredTranslation(); + final String translationString = translation.getTranslationString(); + final List<String> translatedTokens = translation.getTranslationTokens(); + final float translationScore = translation.getTranslationScore(); + final List<List<Integer>> wordAlignment = translation.getTranslationWordAlignments(); + + // THEN + assertEquals("", translationString); + assertTrue(translatedTokens.isEmpty()); + assertEquals(0, translationScore, 0.00001); + assertTrue(wordAlignment.isEmpty()); + } + + @Test + public void givenOOVInput_whenSaarStructuredOutputFormat_thenOOVOutput() { + // GIVEN + joshuaConfig.construct_structured_output = true; + final String input = "gabarbl"; + + // WHEN + final StructuredTranslation translation = decode(input).getStructuredTranslation(); + final String translationString = translation.getTranslationString(); + final List<String> translatedTokens = translation.getTranslationTokens(); + final float translationScore = translation.getTranslationScore(); + final List<List<Integer>> wordAlignment = translation.getTranslationWordAlignments(); + + // THEN + assertEquals(input, translationString); + assertTrue(translatedTokens.contains(input)); + assertEquals(-199.0, translationScore, 0.00001); + assertTrue(wordAlignment.contains(asList(0))); + } + + @Test + public void givenEmptyInput_whenRegularOutputFormat_thenNewlineOutput() { + // GIVEN + joshuaConfig.construct_structured_output = false; + + // WHEN + final Translation translation = decode(""); + final String translationString = translation.toString(); + + // THEN + assertEquals("\n", translationString); + } + +}
