Repository: incubator-joshua Updated Branches: refs/heads/7 85aa03232 -> 8ba242ec8
Refactoring of ComputeNodeResult into a static method. Project: http://git-wip-us.apache.org/repos/asf/incubator-joshua/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-joshua/commit/dd82f081 Tree: http://git-wip-us.apache.org/repos/asf/incubator-joshua/tree/dd82f081 Diff: http://git-wip-us.apache.org/repos/asf/incubator-joshua/diff/dd82f081 Branch: refs/heads/7 Commit: dd82f081b9519764237b6fbbab3e84a397b9594c Parents: ff9af37 Author: Tobias Domhan <domh...@amazon.com> Authored: Wed Sep 14 15:12:54 2016 +0200 Committer: Tobias Domhan <domh...@amazon.com> Committed: Wed Sep 14 15:12:54 2016 +0200 ---------------------------------------------------------------------- .../joshua/decoder/chart_parser/Cell.java | 4 +- .../joshua/decoder/chart_parser/Chart.java | 16 +- .../decoder/chart_parser/ComputeNodeResult.java | 288 ++++++++----------- .../decoder/chart_parser/CubePruneState.java | 16 +- .../joshua/decoder/chart_parser/NodeResult.java | 86 ++++++ .../apache/joshua/decoder/phrase/Candidate.java | 9 +- .../joshua/decoder/phrase/PhraseChart.java | 5 +- .../apache/joshua/decoder/phrase/Stacks.java | 4 +- 8 files changed, 236 insertions(+), 192 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Cell.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Cell.java b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Cell.java index a771bec..6eff27f 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Cell.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Cell.java @@ -159,8 +159,8 @@ class Cell { * * @return the new hypernode, or null if the cell was pruned. */ - HGNode addHyperEdgeInCell(ComputeNodeResult result, Rule rule, int i, int j, List<HGNode> ants, - SourcePath srcPath, boolean noPrune) { + HGNode addHyperEdgeInCell(NodeResult result, Rule rule, int i, int j, List<HGNode> ants, + SourcePath srcPath, boolean noPrune) { // System.err.println(String.format("ADD_EDGE(%d-%d): %s", i, j, rule.getRuleString())); // if (ants != null) { http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Chart.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Chart.java b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Chart.java index bd91a6f..a1c3093 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Chart.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/Chart.java @@ -46,6 +46,8 @@ import org.apache.joshua.util.ChartSpan; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.apache.joshua.decoder.chart_parser.ComputeNodeResult.computeNodeResult; + /** * Chart class this class implements chart-parsing: (1) seeding the chart (2) * cky main loop over bins, (3) identify applicable rules in each bin @@ -241,7 +243,7 @@ public class Chart { break; } - ComputeNodeResult result = new ComputeNodeResult(this.featureFunctions, rule, null, i, + NodeResult result = computeNodeResult(this.featureFunctions, rule, null, i, j, sourcePath, this.sentence); if (stateConstraint == null || stateConstraint.isLegal(result.getDPStates())) { @@ -270,7 +272,7 @@ public class Chart { int[] ranks = new int[1 + superNodes.size()]; Arrays.fill(ranks, 1); - ComputeNodeResult result = new ComputeNodeResult(featureFunctions, bestRule, + NodeResult result = computeNodeResult(featureFunctions, bestRule, currentTailNodes, i, j, sourcePath, sentence); CubePruneState bestState = new CubePruneState(result, ranks, rules, currentTailNodes, dotNode); @@ -319,7 +321,7 @@ public class Chart { * doing constrained decoding or (b) we are and the state is legal. */ if (stateConstraint == null || stateConstraint.isLegal(state.getDPStates())) { - getCell(i, j).addHyperEdgeInCell(state.computeNodeResult, state.getRule(), i, j, + getCell(i, j).addHyperEdgeInCell(state.nodeResult, state.getRule(), i, j, state.antNodes, sourcePath, true); } @@ -354,7 +356,7 @@ public class Chart { nextAntNodes.add(superNodes.get(x).nodes.get(nextRanks[x + 1] - 1)); /* Create the next state. */ - CubePruneState nextState = new CubePruneState(new ComputeNodeResult(featureFunctions, + CubePruneState nextState = new CubePruneState(computeNodeResult(featureFunctions, nextRule, nextAntNodes, i, j, sourcePath, this.sentence), nextRanks, rules, nextAntNodes, dotNode); @@ -538,7 +540,7 @@ public class Chart { int[] ranks = new int[1 + superNodes.size()]; Arrays.fill(ranks, 1); - ComputeNodeResult result = new ComputeNodeResult(featureFunctions, bestRule, tailNodes, + NodeResult result = computeNodeResult(featureFunctions, bestRule, tailNodes, dotNode.begin(), dotNode.end(), dotNode.getSourcePath(), sentence); CubePruneState seedState = new CubePruneState(result, ranks, rules, tailNodes, dotNode); @@ -704,7 +706,7 @@ public class Chart { List<Rule> rules = childNode.getRuleCollection().getSortedRules(this.featureFunctions); for (Rule rule : rules) { // for each unary rules - ComputeNodeResult states = new ComputeNodeResult(this.featureFunctions, rule, + NodeResult states = computeNodeResult(this.featureFunctions, rule, antecedents, i, j, new SourcePath(), this.sentence); HGNode resNode = chartBin.addHyperEdgeInCell(states, rule, i, j, antecedents, new SourcePath(), true); @@ -736,7 +738,7 @@ public class Chart { } this.cells.get(i, j).addHyperEdgeInCell( - new ComputeNodeResult(this.featureFunctions, rule, null, i, j, srcPath, sentence), rule, i, + computeNodeResult(this.featureFunctions, rule, null, i, j, srcPath, sentence), rule, i, j, null, srcPath, false); } http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/ComputeNodeResult.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/ComputeNodeResult.java b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/ComputeNodeResult.java index 0e7cd6d..cfaf96a 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/ComputeNodeResult.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/ComputeNodeResult.java @@ -18,10 +18,6 @@ */ package org.apache.joshua.decoder.chart_parser; -import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature; - -import java.util.ArrayList; -import java.util.List; import org.apache.joshua.decoder.Decoder; import org.apache.joshua.decoder.ff.FeatureFunction; @@ -35,189 +31,141 @@ import org.apache.joshua.decoder.segment_file.Sentence; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * This class computes the cost of applying a rule. - * - * @author Matt Post p...@cs.jhu.edu - * @author Zhifei Li, zhifei.w...@gmail.com - */ +import java.util.ArrayList; +import java.util.List; + +import static org.apache.joshua.decoder.ff.FeatureMap.hashFeature; public class ComputeNodeResult { + private ComputeNodeResult() {}; + + private static final Logger LOG = LoggerFactory.getLogger(NodeResult.class); + + /** + * Computes the new state(s) that are produced when applying the given rule to the list of tail + * nodes. Also computes a range of costs of doing so (the transition cost, the total (Viterbi) + * cost, and a score that includes a future cost estimate). + * + * Old version that doesn't use the derivation state. + * @param featureFunctions {@link java.util.List} of {@link org.apache.joshua.decoder.ff.FeatureFunction}'s + * @param rule {@link org.apache.joshua.decoder.ff.tm.Rule} to use when computing th node result + * @param tailNodes {@link java.util.List} of {@link org.apache.joshua.decoder.hypergraph.HGNode}'s + * @param i todo + * @param j todo + * @param sourcePath information about a path taken through the source lattice + * @param sentence the lattice input + */ + public static NodeResult computeNodeResult(List<FeatureFunction> featureFunctions, Rule rule, List<HGNode> tailNodes, + int i, int j, SourcePath sourcePath, Sentence sentence) { - private static final Logger LOG = LoggerFactory.getLogger(ComputeNodeResult.class); - - // The cost incurred by the rule itself (and all associated feature functions) - private float transitionCost; - - // transitionCost + the Viterbi costs of the tail nodes. - private float viterbiCost; - - // The future or outside cost (estimated) - private float futureCostEstimate; - - // The StateComputer objects themselves serve as keys. - private final List<DPState> dpStates; - - /** - * Computes the new state(s) that are produced when applying the given rule to the list of tail - * nodes. Also computes a range of costs of doing so (the transition cost, the total (Viterbi) - * cost, and a score that includes a future cost estimate). - * - * Old version that doesn't use the derivation state. - * @param featureFunctions {@link java.util.List} of {@link org.apache.joshua.decoder.ff.FeatureFunction}'s - * @param rule {@link org.apache.joshua.decoder.ff.tm.Rule} to use when computing th node result - * @param tailNodes {@link java.util.List} of {@link org.apache.joshua.decoder.hypergraph.HGNode}'s - * @param i todo - * @param j todo - * @param sourcePath information about a path taken through the source lattice - * @param sentence the lattice input - */ - public ComputeNodeResult(List<FeatureFunction> featureFunctions, Rule rule, List<HGNode> tailNodes, - int i, int j, SourcePath sourcePath, Sentence sentence) { - - // The total Viterbi cost of this edge. This is the Viterbi cost of the tail nodes, plus - // whatever costs we incur applying this rule to create a new hyperedge. - this.viterbiCost = 0.0f; - - if (LOG.isDebugEnabled()) { - LOG.debug("ComputeNodeResult():"); - LOG.debug("-> RULE {}", rule); - } + // The total Viterbi cost of this edge. This is the Viterbi cost of the tail nodes, plus + // whatever costs we incur applying this rule to create a new hyperedge. + float viterbiCost = 0.0f; - /* - * Here we sum the accumulated cost of each of the tail nodes. The total cost of the new - * hyperedge (the inside or Viterbi cost) is the sum of these nodes plus the cost of the - * transition. Note that this could and should all be generalized to whatever semiring is being - * used. - */ - if (null != tailNodes) { - for (HGNode item : tailNodes) { - if (LOG.isDebugEnabled()) { - LOG.debug("-> item.bestedge: {}", item); - LOG.debug("-> TAIL NODE {}", item); + if (LOG.isDebugEnabled()) { + LOG.debug("NodeResult():"); + LOG.debug("-> RULE {}", rule); + } + + /* + * Here we sum the accumulated cost of each of the tail nodes. The total cost of the new + * hyperedge (the inside or Viterbi cost) is the sum of these nodes plus the cost of the + * transition. Note that this could and should all be generalized to whatever semiring is being + * used. + */ + if (null != tailNodes) { + for (HGNode item : tailNodes) { + if (LOG.isDebugEnabled()) { + LOG.debug("-> item.bestedge: {}", item); + LOG.debug("-> TAIL NODE {}", item); + } + viterbiCost += item.bestHyperedge.getBestDerivationScore(); } - viterbiCost += item.bestHyperedge.getBestDerivationScore(); } - } - List<DPState> allDPStates = new ArrayList<>(); + List<DPState> allDPStates = new ArrayList<>(); - // The transition cost is the new cost incurred by applying this rule - this.transitionCost = 0.0f; + // The transition cost is the new cost incurred by applying this rule + float transitionCost = 0.0f; - // The future cost estimate is a heuristic estimate of the outside cost of this edge. - this.futureCostEstimate = 0.0f; + // The future cost estimate is a heuristic estimate of the outside cost of this edge. + float futureCostEstimate = 0.0f; - /* - * We now iterate over all the feature functions, computing their cost and their expected future - * cost. - */ - for (FeatureFunction feature : featureFunctions) { - FeatureFunction.ScoreAccumulator acc = feature.new ScoreAccumulator(); + /* + * We now iterate over all the feature functions, computing their cost and their expected future + * cost. + */ + for (FeatureFunction feature : featureFunctions) { + FeatureFunction.ScoreAccumulator acc = feature.new ScoreAccumulator(); - DPState newState = feature.compute(rule, tailNodes, i, j, sourcePath, sentence, acc); - this.transitionCost += acc.getScore(); + DPState newState = feature.compute(rule, tailNodes, i, j, sourcePath, sentence, acc); + transitionCost += acc.getScore(); + if (LOG.isDebugEnabled()) { + LOG.debug("FEATURE {} = {} * {} = {}", feature.getName(), + acc.getScore() / Decoder.weights.getOrDefault(hashFeature(feature.getName())), + Decoder.weights.getOrDefault(hashFeature(feature.getName())), acc.getScore()); + } - if (LOG.isDebugEnabled()) { - LOG.debug("FEATURE {} = {} * {} = {}", feature.getName(), - acc.getScore() / Decoder.weights.getOrDefault(hashFeature(feature.getName())), - Decoder.weights.getOrDefault(hashFeature(feature.getName())), acc.getScore()); + if (feature.isStateful()) { + futureCostEstimate += feature.estimateFutureCost(rule, newState, sentence); + allDPStates.add(((StatefulFF)feature).getStateIndex(), newState); + } } + viterbiCost += transitionCost; + if (LOG.isDebugEnabled()) + LOG.debug("-> COST = {}", transitionCost); - if (feature.isStateful()) { - futureCostEstimate += feature.estimateFutureCost(rule, newState, sentence); - allDPStates.add(((StatefulFF)feature).getStateIndex(), newState); - } - } - this.viterbiCost += transitionCost; - if (LOG.isDebugEnabled()) - LOG.debug("-> COST = {}", transitionCost); - - this.dpStates = allDPStates; - } - - /** - * This is called from {@link org.apache.joshua.decoder.chart_parser.Cell} - * when making the final transition to the goal state. - * This is done to allow feature functions to correct for partial estimates, since - * they now have the knowledge that the whole sentence is complete. Basically, this - * is only used by LanguageModelFF, which does not score partial n-grams, and therefore - * needs to correct for this when a short sentence ends. KenLMFF corrects for this by - * always scoring partial hypotheses, and subtracting off the partial score when longer - * context is available. This would be good to do for the LanguageModelFF feature function, - * too: it makes search better (more accurate at the beginning, for example), and would - * also do away with the need for the computeFinal* class of functions (and hooks in - * the feature function interface). - * - * @param featureFunctions {@link java.util.List} of {@link org.apache.joshua.decoder.ff.FeatureFunction}'s - * @param tailNodes {@link java.util.List} of {@link org.apache.joshua.decoder.hypergraph.HGNode}'s - * @param i todo - * @param j todo - * @param sourcePath information about a path taken through the source lattice - * @param sentence the lattice input - * @return the final cost for the Node - */ - public static float computeFinalCost(List<FeatureFunction> featureFunctions, - List<HGNode> tailNodes, int i, int j, SourcePath sourcePath, Sentence sentence) { - - float cost = 0; - for (FeatureFunction ff : featureFunctions) { - cost += ff.computeFinalCost(tailNodes.get(0), i, j, sourcePath, sentence); + return new NodeResult(transitionCost, viterbiCost, futureCostEstimate, allDPStates); } - return cost; - } - - public static FeatureVector computeTransitionFeatures(List<FeatureFunction> featureFunctions, - HyperEdge edge, int i, int j, Sentence sentence) { - - // Initialize the set of features with those that were present with the rule in the grammar. - FeatureVector featureDelta = new FeatureVector(featureFunctions.size()); - - // === compute feature logPs - for (FeatureFunction ff : featureFunctions) { - // A null rule signifies the final transition. - if (edge.getRule() == null) - featureDelta.addInPlace(ff.computeFinalFeatures(edge.getTailNodes().get(0), i, j, edge.getSourcePath(), sentence)); - else { - featureDelta.addInPlace(ff.computeFeatures(edge.getRule(), edge.getTailNodes(), i, j, edge.getSourcePath(), sentence)); + + + /** + * This is called from {@link Cell} + * when making the final transition to the goal state. + * This is done to allow feature functions to correct for partial estimates, since + * they now have the knowledge that the whole sentence is complete. Basically, this + * is only used by LanguageModelFF, which does not score partial n-grams, and therefore + * needs to correct for this when a short sentence ends. KenLMFF corrects for this by + * always scoring partial hypotheses, and subtracting off the partial score when longer + * context is available. This would be good to do for the LanguageModelFF feature function, + * too: it makes search better (more accurate at the beginning, for example), and would + * also do away with the need for the computeFinal* class of functions (and hooks in + * the feature function interface). + * + * @param featureFunctions {@link List} of {@link FeatureFunction}'s + * @param tailNodes {@link List} of {@link HGNode}'s + * @param i todo + * @param j todo + * @param sourcePath information about a path taken through the source lattice + * @param sentence the lattice input + * @return the final cost for the Node + */ + public static float computeFinalCost(List<FeatureFunction> featureFunctions, + List<HGNode> tailNodes, int i, int j, SourcePath sourcePath, Sentence sentence) { + + float cost = 0; + for (FeatureFunction ff : featureFunctions) { + cost += ff.computeFinalCost(tailNodes.get(0), i, j, sourcePath, sentence); } + return cost; } - return featureDelta; - } - - public float getFutureEstimate() { - return this.futureCostEstimate; - } - - public float getPruningEstimate() { - return getViterbiCost() + getFutureEstimate(); - } - - /** - * The complete cost of the Viterbi derivation at this point. - * - * @return float representing cost - */ - public float getViterbiCost() { - return this.viterbiCost; - } - - public float getBaseCost() { - return getViterbiCost() - getTransitionCost(); - } - - /** - * The cost incurred by this edge alone - * - * @return float representing cost - */ - public float getTransitionCost() { - return this.transitionCost; - } - - public List<DPState> getDPStates() { - return this.dpStates; - } + public static FeatureVector computeTransitionFeatures(List<FeatureFunction> featureFunctions, + HyperEdge edge, int i, int j, Sentence sentence) { + // Initialize the set of features with those that were present with the rule in the grammar. + FeatureVector featureDelta = new FeatureVector(featureFunctions.size()); + + // === compute feature logPs + for (FeatureFunction ff : featureFunctions) { + // A null rule signifies the final transition. + if (edge.getRule() == null) + featureDelta.addInPlace(ff.computeFinalFeatures(edge.getTailNodes().get(0), i, j, edge.getSourcePath(), sentence)); + else { + featureDelta.addInPlace(ff.computeFeatures(edge.getRule(), edge.getTailNodes(), i, j, edge.getSourcePath(), sentence)); + } + } + + return featureDelta; + } } http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/CubePruneState.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/CubePruneState.java b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/CubePruneState.java index 1f06d30..0106ac5 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/CubePruneState.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/CubePruneState.java @@ -31,13 +31,13 @@ import org.apache.joshua.decoder.ff.tm.Rule; // =============================================================== public class CubePruneState implements Comparable<CubePruneState> { final int[] ranks; - final ComputeNodeResult computeNodeResult; + final NodeResult nodeResult; final List<HGNode> antNodes; final List<Rule> rules; private DotNode dotNode; - public CubePruneState(ComputeNodeResult score, int[] ranks, List<Rule> rules, List<HGNode> antecedents, DotNode dotNode) { - this.computeNodeResult = score; + public CubePruneState(NodeResult score, int[] ranks, List<Rule> rules, List<HGNode> antecedents, DotNode dotNode) { + this.nodeResult = score; this.ranks = ranks; this.rules = rules; this.antNodes = antecedents; @@ -50,7 +50,7 @@ public class CubePruneState implements Comparable<CubePruneState> { * @return */ List<DPState> getDPStates() { - return this.computeNodeResult.getDPStates(); + return this.nodeResult.getDPStates(); } Rule getRule() { @@ -59,8 +59,8 @@ public class CubePruneState implements Comparable<CubePruneState> { public String toString() { String sb = "STATE ||| rule=" + getRule() + " inside cost = " + - computeNodeResult.getViterbiCost() + " estimate = " + - computeNodeResult.getPruningEstimate(); + nodeResult.getViterbiCost() + " estimate = " + + nodeResult.getPruningEstimate(); return sb; } @@ -99,10 +99,10 @@ public class CubePruneState implements Comparable<CubePruneState> { * order (high-prob first). */ public int compareTo(CubePruneState another) { - if (this.computeNodeResult.getPruningEstimate() < another.computeNodeResult + if (this.nodeResult.getPruningEstimate() < another.nodeResult .getPruningEstimate()) { return 1; - } else if (this.computeNodeResult.getPruningEstimate() == another.computeNodeResult + } else if (this.nodeResult.getPruningEstimate() == another.nodeResult .getPruningEstimate()) { return 0; } else { http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/NodeResult.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/NodeResult.java b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/NodeResult.java new file mode 100644 index 0000000..ade2c01 --- /dev/null +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/chart_parser/NodeResult.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.joshua.decoder.chart_parser; + +import java.util.List; + +import org.apache.joshua.decoder.ff.state_maintenance.DPState; + +/** + * This class represents the cost of applying a rule. + * + * @author Matt Post p...@cs.jhu.edu + * @author Zhifei Li, zhifei.w...@gmail.com + */ +public class NodeResult { + + // The cost incurred by the rule itself (and all associated feature functions) + private final float transitionCost; + + // transitionCost + the Viterbi costs of the tail nodes. + private final float viterbiCost; + + // The future or outside cost (estimated) + private final float futureCostEstimate; + + // The StateComputer objects themselves serve as keys. + private final List<DPState> dpStates; + + public NodeResult(float transitionCost, float viterbiCost, float futureCostEstimate, List<DPState> dpStates) { + this.transitionCost = transitionCost; + this.viterbiCost = viterbiCost; + this.futureCostEstimate = futureCostEstimate; + this.dpStates = dpStates; + } + + public float getFutureEstimate() { + return this.futureCostEstimate; + } + + public float getPruningEstimate() { + return getViterbiCost() + getFutureEstimate(); + } + + /** + * The complete cost of the Viterbi derivation at this point. + * + * @return float representing cost + */ + public float getViterbiCost() { + return this.viterbiCost; + } + + public float getBaseCost() { + return getViterbiCost() - getTransitionCost(); + } + + /** + * The cost incurred by this edge alone + * + * @return float representing cost + */ + public float getTransitionCost() { + return this.transitionCost; + } + + public List<DPState> getDPStates() { + return this.dpStates; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Candidate.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Candidate.java b/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Candidate.java index c5f96e5..9f0dec1 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Candidate.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Candidate.java @@ -38,12 +38,15 @@ import java.util.List; import org.apache.joshua.corpus.Span; import org.apache.joshua.decoder.chart_parser.ComputeNodeResult; +import org.apache.joshua.decoder.chart_parser.NodeResult; import org.apache.joshua.decoder.ff.FeatureFunction; import org.apache.joshua.decoder.ff.state_maintenance.DPState; import org.apache.joshua.decoder.ff.tm.Rule; import org.apache.joshua.decoder.hypergraph.HGNode; import org.apache.joshua.decoder.segment_file.Sentence; +import static org.apache.joshua.decoder.chart_parser.ComputeNodeResult.computeNodeResult; + public class Candidate implements Comparable<Candidate> { private List<FeatureFunction> featureFunctions; @@ -68,7 +71,7 @@ public class Candidate implements Comparable<Candidate> { * Stores the inside cost of the current phrase, as well as the computed dynamic programming * state. Expensive to compute so there is an option of delaying it. */ - private ComputeNodeResult computedResult; + private NodeResult computedResult; /** * When candidate objects are extended, the new one is initialized with the same underlying @@ -225,11 +228,11 @@ public class Candidate implements Comparable<Candidate> { * * @return the computed result. */ - public ComputeNodeResult computeResult() { + public NodeResult computeResult() { if (computedResult == null) { // add the rule // TODO: sourcepath - computedResult = new ComputeNodeResult(featureFunctions, getRule(), getTailNodes(), getLastCovered(), getPhraseEnd(), null, sentence); + computedResult = computeNodeResult(featureFunctions, getRule(), getTailNodes(), getLastCovered(), getPhraseEnd(), null, sentence); } return computedResult; http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/PhraseChart.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/PhraseChart.java b/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/PhraseChart.java index b7c3001..8355eb5 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/PhraseChart.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/PhraseChart.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.List; import org.apache.joshua.decoder.chart_parser.ComputeNodeResult; +import org.apache.joshua.decoder.chart_parser.NodeResult; import org.apache.joshua.decoder.ff.FeatureFunction; import org.apache.joshua.decoder.ff.tm.Rule; import org.apache.joshua.decoder.ff.tm.RuleCollection; @@ -32,6 +33,8 @@ import org.apache.joshua.decoder.segment_file.Sentence; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.apache.joshua.decoder.chart_parser.ComputeNodeResult.computeNodeResult; + /** * This class represents a bundle of phrase tables that have been read in, * reporting some stats about them. Probably could be done away with. @@ -205,7 +208,7 @@ public class PhraseChart { // Turn each rule into an HGNode, add them one by one for (Rule rule: rules) { - ComputeNodeResult result = new ComputeNodeResult(features, rule, null, i, j, null, sentence); + NodeResult result = computeNodeResult(features, rule, null, i, j, null, sentence); HyperEdge edge = new HyperEdge(rule, result.getViterbiCost(), result.getTransitionCost(), null, null); HGNode phraseNode = new HGNode(i, j, rule.getLHS(), result.getDPStates(), edge, result.getPruningEstimate()); nodes.add(phraseNode); http://git-wip-us.apache.org/repos/asf/incubator-joshua/blob/dd82f081/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Stacks.java ---------------------------------------------------------------------- diff --git a/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Stacks.java b/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Stacks.java index 230ed09..8fae284 100644 --- a/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Stacks.java +++ b/joshua-core/src/main/java/org/apache/joshua/decoder/phrase/Stacks.java @@ -35,6 +35,7 @@ package org.apache.joshua.decoder.phrase; * TODO Lattice decoding is not yet supported. */ +import static org.apache.joshua.decoder.chart_parser.ComputeNodeResult.computeNodeResult; import static org.apache.joshua.decoder.ff.tm.OwnerMap.UNKNOWN_OWNER; import java.util.ArrayList; @@ -42,6 +43,7 @@ import java.util.List; import org.apache.joshua.decoder.JoshuaConfiguration; import org.apache.joshua.decoder.chart_parser.ComputeNodeResult; +import org.apache.joshua.decoder.chart_parser.NodeResult; import org.apache.joshua.decoder.ff.FeatureFunction; import org.apache.joshua.decoder.ff.tm.AbstractGrammar; import org.apache.joshua.decoder.ff.tm.Grammar; @@ -123,7 +125,7 @@ public class Stacks { stacks.add(null); // Initialize root hypothesis with <s> context and future cost for everything. - ComputeNodeResult result = new ComputeNodeResult(this.featureFunctions, Hypothesis.BEGIN_RULE, + NodeResult result = computeNodeResult(this.featureFunctions, Hypothesis.BEGIN_RULE, null, -1, 1, null, this.sentence); Stack firstStack = new Stack(sentence, config); firstStack.add(new Hypothesis(result.getDPStates(), future.Full()));