This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
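Most of this cleanup touches the transform/tokenize package. As background for the diff below: a tokenizer is configured through a JSON spec and instantiated via TokenizerFactory.createTokenizer(spec, maxTokens); the spec names a build algorithm ("split" or "ngram") and an output representation ("count", "position", or "hash"), plus the text column and id columns. The following sketch illustrates one such spec and call; the input-frame setup and column indices are illustrative assumptions, only the spec keys and the factory/Tokenizer methods are taken from the code shown in this commit.

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;

public class TokenizeSketch {
	public static void main(String[] args) {
		// Hypothetical spec: whitespace split into a long-format bag-of-words count output.
		// The keys ("algo", "out", "out_params", "tokenize_col", "id_cols") are those parsed
		// by TokenizerFactory.createTokenizer in the diff below; the values are examples.
		String spec = "{\"algo\": \"split\", \"out\": \"count\","
			+ " \"out_params\": {\"sort_alpha\": true},"
			+ " \"tokenize_col\": 2, \"id_cols\": [1]}";

		// One id (document key) column and one text column to tokenize -- illustrative input.
		FrameBlock in = new FrameBlock(
			new Types.ValueType[] {Types.ValueType.STRING, Types.ValueType.STRING});
		in.appendRow(new String[] {"doc1", "a minor fix of warnings imports and formatting"});

		Tokenizer tokenizer = TokenizerFactory.createTokenizer(spec, 64); // maxTokens = 64
		FrameBlock out = tokenizer.tokenize(in, 4);                       // k = 4 threads
		System.out.println(out.getNumRows() + " x " + out.getNumColumns());
	}
}

With out=count and the default long format, each output row is (id columns, token, count), as implemented by TokenizerApplierCount in the diff below; wide format and padding are controlled by "format_wide" and "apply_padding" in the same spec.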
commit 47ea25f624c178e155371637d6b2528064a9a88f Author: Matthias Boehm <[email protected]> AuthorDate: Fri Aug 5 14:12:00 2022 +0200 [MINOR] Fix warnings, imports, and formatting issues --- .../java/org/apache/sysds/api/PythonDMLScript.java | 1 - .../matrix/data/LibMatrixCountDistinct.java | 7 +- .../transform/tokenize/DocumentRepresentation.java | 41 ++- .../sysds/runtime/transform/tokenize/Token.java | 184 +++++----- .../runtime/transform/tokenize/Tokenizer.java | 385 ++++++++++----------- .../transform/tokenize/TokenizerFactory.java | 169 +++++---- .../tokenize/applier/TokenizerApplier.java | 343 +++++++++--------- .../tokenize/applier/TokenizerApplierCount.java | 168 +++++---- .../tokenize/applier/TokenizerApplierHash.java | 6 - .../tokenize/applier/TokenizerApplierPosition.java | 2 - .../tokenize/builder/TokenizerBuilder.java | 100 +++--- .../tokenize/builder/TokenizerBuilderNgram.java | 120 +++---- .../builder/TokenizerBuilderWhitespaceSplit.java | 90 +++-- .../builtin/part2/BuiltinTomeklinkTest.java | 1 - .../transform/TokenizeMultithreadedTest.java | 298 ++++++++-------- 15 files changed, 935 insertions(+), 980 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/PythonDMLScript.java b/src/main/java/org/apache/sysds/api/PythonDMLScript.java index 03ff025892..61a2ca823f 100644 --- a/src/main/java/org/apache/sysds/api/PythonDMLScript.java +++ b/src/main/java/org/apache/sysds/api/PythonDMLScript.java @@ -22,7 +22,6 @@ package org.apache.sysds.api; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.jmlc.Connection; -import org.apache.sysds.conf.CompilerConfig; import py4j.GatewayServer; import py4j.GatewayServerListener; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java index 1198b18dd5..814b7737f9 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java @@ -19,7 +19,10 @@ package org.apache.sysds.runtime.matrix.data; -import java.util.*; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; import org.apache.commons.lang.NotImplementedException; import org.apache.commons.logging.Log; @@ -33,8 +36,6 @@ import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator; import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes; import org.apache.sysds.utils.Hash.HashType; -import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; - /** * This class contains various methods for counting the number of distinct values inside a MatrixBlock */ diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java index b52ef34b46..af7ef50c6a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/DocumentRepresentation.java @@ -21,31 +21,30 @@ package org.apache.sysds.runtime.transform.tokenize; import java.util.ArrayList; import java.util.List; -import java.util.stream.Collectors; public class DocumentRepresentation { - public List<Object> keys; - public List<Token> tokens; + public List<Object> keys; + public List<Token> tokens; - public DocumentRepresentation(List<Object> keys, 
List<Token> tokens) { - this.keys = keys; - this.tokens = tokens; - } + public DocumentRepresentation(List<Object> keys, List<Token> tokens) { + this.keys = keys; + this.tokens = tokens; + } - public List<Token> getTokens() { - return tokens; - } + public List<Token> getTokens() { + return tokens; + } - public void splitIntoNgrams(int minGram, int maxGram){ - List<Token> ngramTokens = new ArrayList<>(); - for(int n = minGram; n <= maxGram; n++){ - for(int i = 0; i < tokens.size() - n + 1; i++){ - List<Token> subList = tokens.subList(i, i+n); - Token token = new Token(subList); - ngramTokens.add(token); - } - } - tokens = ngramTokens; - } + public void splitIntoNgrams(int minGram, int maxGram){ + List<Token> ngramTokens = new ArrayList<>(); + for(int n = minGram; n <= maxGram; n++){ + for(int i = 0; i < tokens.size() - n + 1; i++){ + List<Token> subList = tokens.subList(i, i+n); + Token token = new Token(subList); + ngramTokens.add(token); + } + } + tokens = ngramTokens; + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java index 990f7e0f71..1be29ed764 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Token.java @@ -22,103 +22,97 @@ package org.apache.sysds.runtime.transform.tokenize; import org.apache.sysds.runtime.DMLRuntimeException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; -import java.util.stream.IntStream; public class Token { - public static final String EMPTY_TOKEN = ""; - - public static class SubToken { - private final String text; - private final long startIndex; - private final long endIndex; - - public SubToken(String token, long startIndex) { - this.text = token; - this.startIndex = startIndex; - this.endIndex = startIndex + token.length(); - } - - @Override - public String toString() { - return "SubToken{" + - "textToken='" + text + '\'' + - ", startIndex=" + startIndex + - ", endIndex=" + endIndex + - '}'; - } - } - - private List<SubToken> subTokens; - - private Token(int subListSize){ - subTokens = new ArrayList<>(subListSize); - } - - public Token(String token, long startIndex) { - this(1); - subTokens.add(new SubToken(token, startIndex)); - } - - public Token(List<String> tokens, List<Long> startIndex){ - this(tokens.size()); - if(tokens.size() != startIndex.size()) - throw new DMLRuntimeException("Cannot create token from mismatched input sizes"); - for(int i = 0; i < tokens.size(); i++){ - subTokens.add(new SubToken(tokens.get(i), startIndex.get(i))); - } - } - - public Token(List<Token> subList) { - this(getNumSubTokens(subList)); - for(Token token: subList){ - subTokens.addAll(token.subTokens); - } - } - - private static int getNumSubTokens(List<Token> tokens){ - int sum = 0; - for (Token token : tokens) { - sum += token.getNumSubTokens(); - } - return sum; - } - - public int getNumSubTokens(){ - return subTokens.size(); - } - - public long getStartIndex(int subTokenIndex){ - return subTokens.get(subTokenIndex).startIndex; - } - - @Override - public int hashCode() { - return toString().hashCode(); - } - - @Override - public String toString() { - if(subTokens.size() == 0){ - return EMPTY_TOKEN; - } - if(subTokens.size() == 1){ - return subTokens.get(0).text; - } - StringBuilder sb = new StringBuilder().append("\"('"); - for(int i = 0; i < subTokens.size(); i++){ - 
sb.append(subTokens.get(i).text); - if(i < subTokens.size()-1) - sb.append("', '"); - } - sb.append("')\""); - //return "\"('" + subTokens.stream().map(subToken -> subToken.text).collect(Collectors.joining("', '")) + "')\""; - return sb.toString(); - } - - + public static final String EMPTY_TOKEN = ""; + + public static class SubToken { + private final String text; + private final long startIndex; + private final long endIndex; + + public SubToken(String token, long startIndex) { + this.text = token; + this.startIndex = startIndex; + this.endIndex = startIndex + token.length(); + } + + @Override + public String toString() { + return "SubToken{" + + "textToken='" + text + '\'' + + ", startIndex=" + startIndex + + ", endIndex=" + endIndex + + '}'; + } + } + + private List<SubToken> subTokens; + + private Token(int subListSize){ + subTokens = new ArrayList<>(subListSize); + } + + public Token(String token, long startIndex) { + this(1); + subTokens.add(new SubToken(token, startIndex)); + } + + public Token(List<String> tokens, List<Long> startIndex){ + this(tokens.size()); + if(tokens.size() != startIndex.size()) + throw new DMLRuntimeException("Cannot create token from mismatched input sizes"); + for(int i = 0; i < tokens.size(); i++){ + subTokens.add(new SubToken(tokens.get(i), startIndex.get(i))); + } + } + + public Token(List<Token> subList) { + this(getNumSubTokens(subList)); + for(Token token: subList){ + subTokens.addAll(token.subTokens); + } + } + + private static int getNumSubTokens(List<Token> tokens){ + int sum = 0; + for (Token token : tokens) { + sum += token.getNumSubTokens(); + } + return sum; + } + + public int getNumSubTokens(){ + return subTokens.size(); + } + + public long getStartIndex(int subTokenIndex){ + return subTokens.get(subTokenIndex).startIndex; + } + + @Override + public int hashCode() { + return toString().hashCode(); + } + + @Override + public String toString() { + if(subTokens.size() == 0){ + return EMPTY_TOKEN; + } + if(subTokens.size() == 1){ + return subTokens.get(0).text; + } + StringBuilder sb = new StringBuilder().append("\"('"); + for(int i = 0; i < subTokens.size(); i++){ + sb.append(subTokens.get(i).text); + if(i < subTokens.size()-1) + sb.append("', '"); + } + sb.append("')\""); + //return "\"('" + subTokens.stream().map(subToken -> subToken.text).collect(Collectors.joining("', '")) + "')\""; + return sb.toString(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java index 6ba0dcb4f8..de2bb9a5a7 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/Tokenizer.java @@ -23,7 +23,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.transform.tokenize.applier.TokenizerApplier; @@ -46,197 +45,195 @@ import java.util.concurrent.ExecutionException; public class Tokenizer implements Serializable { - private static final long serialVersionUID = 7155673772374114577L; - protected static final Log LOG = LogFactory.getLog(Tokenizer.class.getName()); - private static final boolean MULTI_THREADED_STAGES_TOKENIZER = false; - public static final int 
TOKENIZE_NUM_BLOCKS = ConfigurationManager.getNumberTokenizeBlocks(); - - private DocumentRepresentation[] internalRepresentation = null; - private final TokenizerBuilder tokenizerBuilder; - private final TokenizerApplier tokenizerApplier; - - protected Tokenizer(TokenizerBuilder tokenizerBuilder, TokenizerApplier tokenizerApplier) { - this.tokenizerBuilder = tokenizerBuilder; - this.tokenizerApplier = tokenizerApplier; - } - - public Types.ValueType[] getSchema() { - return tokenizerApplier.getOutSchema(); - } - - public int getMaxNumRows(int inRows) { - return tokenizerApplier.getMaxNumRows(inRows); - } - - public int getNumRowsEstimate(){ - // Estimate upperbound because e.g. Count Applier has less since it only outputs each unique token once - if(internalRepresentation != null){ - if(tokenizerApplier.isWideFormat()) { - return internalRepresentation.length; - } - if(tokenizerApplier.hasPadding()) { - return internalRepresentation.length * tokenizerApplier.getMaxTokens(); - } - return Arrays.stream(internalRepresentation).mapToInt(doc -> Math.min(doc.tokens.size(), tokenizerApplier.getMaxTokens())).sum(); - } - throw new DMLRuntimeException("Internal Token Representation was not computed yet. Can not get exact size."); - } - - public long getNumCols() { - return tokenizerApplier.getNumCols(); - } - - public void allocateInternalRepresentation(int numDocuments){ - internalRepresentation = new DocumentRepresentation[numDocuments]; - tokenizerApplier.allocateInternalMeta(numDocuments); - } - - public FrameBlock tokenize(FrameBlock in) { - return tokenize(in, 1); - } - - public FrameBlock tokenize(FrameBlock in, int k) { - allocateInternalRepresentation(in.getNumRows()); - FrameBlock out = new FrameBlock(this.getSchema()); - if (k > 1 && !MULTI_THREADED_STAGES_TOKENIZER) { - DependencyThreadPool pool = new DependencyThreadPool(k); - LOG.debug("Tokenizing with full DAG on " + k + " Threads"); - try { - List<DependencyTask<?>> tokenizeTasks = getTokenizeTasks(in, out, pool); - int lastRow = pool.submitAllAndWait(tokenizeTasks).stream().map(s -> s == null? 0 :(Integer)s).max(Integer::compare).get(); - if(lastRow != out.getNumRows()){ - out = out.slice(0, lastRow - 1, 0, out.getNumColumns() - 1, null); - } - } catch (ExecutionException | InterruptedException e) { - LOG.error("MT tokenize failed"); - e.printStackTrace(); - } - pool.shutdown(); - } else { - build(in, k); - out.ensureAllocatedColumns(tokenizerApplier.getNumRows(this.internalRepresentation)); - out = apply(out, k); - } - return out; - } - - private List<DependencyTask<?>> getTokenizeTasks(FrameBlock in, FrameBlock out, DependencyThreadPool pool) { - // TODO further optimisation of task graph to reduce memory usage! 
- // TODO add cache awareness - List<DependencyTask<?>> tasks = new ArrayList<>(); - Map<Integer[], Integer[]> depMap = new HashMap<>(); - tasks.add(DependencyThreadPool.createDependencyTask(new AllocateOutputFrame(this, out))); - List<DependencyTask<?>> buildTasks = getBuildTasks(in); // First half is builder build second half is applier build, dependencies already done - tasks.addAll(buildTasks); - List<DependencyTask<?>> applyTasks = tokenizerApplier.getApplyTasks(this.internalRepresentation, out); - if(applyTasks.size() != buildTasks.size() / 2) - throw new DMLRuntimeException("Different block sizes between build and apply tasks currently not supported"); - // Builder creates internal representation for a given section - // Applier builder creates additional meta information which will be needed in the apply step - // If there is long representation and no padding: - // - Count and Hash apply tasks have dependencies to the metadata build task of all previous chunks due to "getOutputRow". - // e.g. apply task starting at row 100 with block size 50 has dependencies to the ApplierBuildTask responsible for sections [0-49] and [50-99]. - // - Same for Position only they are only dependent on the internal representation creation since it does not have metadata. - if(!tokenizerApplier.isWideFormat() || !tokenizerApplier.hasPadding()){ - int buildTaskOffset; - if(tokenizerApplier instanceof TokenizerApplierPosition){ - buildTaskOffset = 0; - } - else if (tokenizerApplier instanceof TokenizerApplierCount || tokenizerApplier instanceof TokenizerApplierHash) { - buildTaskOffset = applyTasks.size(); - } - else{ - throw new DMLRuntimeException("Unknown TokenizerApplier"); - } - depMap.put(new Integer[] {0, 1}, new Integer[]{1, (buildTasks.size()/2) + 1}); - depMap.put(new Integer[] {tasks.size(), tasks.size()+applyTasks.size()}, new Integer[]{0, 1}); - for(int i = 0; i < applyTasks.size(); i++){ - depMap.put(new Integer[] {tasks.size() + i, tasks.size()+applyTasks.size()}, new Integer[]{1+buildTaskOffset + i, 2+buildTaskOffset + i}); - } - } - tasks.addAll(applyTasks); - List<List<? extends Callable<?>>> deps = new ArrayList<>(Collections.nCopies(tasks.size(), null)); - DependencyThreadPool.createDependencyList(tasks, depMap, deps); - return DependencyThreadPool.createDependencyTasks(tasks, deps); - } - - public FrameBlock apply(FrameBlock out, int k) { - int lastRow = -1; - if(k > 1){ - DependencyThreadPool pool = new DependencyThreadPool(k); - try{ - List<DependencyTask<?>> taskList = tokenizerApplier.getApplyTasks(this.internalRepresentation, out); - lastRow = pool.submitAllAndWait(taskList).stream().map(s -> (Integer)s).max(Integer::compare).get(); - } - catch(ExecutionException | InterruptedException e) { - LOG.error("MT Tokenizer apply failed"); - e.printStackTrace(); - } - pool.shutdown(); - - }else{ - lastRow = tokenizerApplier.applyInternalRepresentation(this.internalRepresentation, out); - } - if(lastRow != out.getNumRows()){ - out = out.slice(0, lastRow - 1, 0, out.getNumColumns() - 1, null); - } - - return out; - } - - public List<DependencyTask<?>> getBuildTasks(FrameBlock in){ - List<DependencyTask<?>> tasks = tokenizerBuilder.getTasks(in, this.internalRepresentation); - List<DependencyTask<?>> applierBuildTaskList = tokenizerApplier.getBuildTasks(this.internalRepresentation); - if(tasks.size() != applierBuildTaskList.size()) - throw new DMLRuntimeException("Cannot create dependencies for mismatched array sizes"); - tasks.addAll(applierBuildTaskList); - List<List<? 
extends Callable<?>>> deps = new ArrayList<>(Collections.nCopies(tasks.size(), null)); - Map<Integer[], Integer[]> depMap = new HashMap<>(); - for(int i = 0; i < tasks.size() / 2; i++){ - depMap.put(new Integer[]{i+applierBuildTaskList.size(), i+applierBuildTaskList.size() + 1}, new Integer[] {i, i+1}); - } - DependencyThreadPool.createDependencyList(tasks, depMap, deps); - tasks = DependencyThreadPool.createDependencyTasks(tasks, deps); - return tasks; - } - - public void build(FrameBlock in, int k){ - tokenizerApplier.allocateInternalMeta(in.getNumRows()); - if(k > 1){ - DependencyThreadPool pool = new DependencyThreadPool(k); - try{ - pool.submitAllAndWait(getBuildTasks(in)); - } - catch(ExecutionException | InterruptedException e) { - LOG.error("MT Tokenizer build failed"); - e.printStackTrace(); - } - pool.shutdown(); - - }else{ - tokenizerBuilder.createInternalRepresentation(in, this.internalRepresentation); - tokenizerApplier.build(this.internalRepresentation, 0, -1); - } - } - - - protected static class AllocateOutputFrame implements Callable<Object>{ - - protected final Tokenizer _tokenizer; - protected final FrameBlock _out; - - protected AllocateOutputFrame(Tokenizer tokenizer, - FrameBlock out){ - this._tokenizer = tokenizer; - this._out = out; - } - - @Override - public Object call() throws Exception { - _out.ensureAllocatedColumns(_tokenizer.getNumRowsEstimate()); - return null; - } - } - - + private static final long serialVersionUID = 7155673772374114577L; + protected static final Log LOG = LogFactory.getLog(Tokenizer.class.getName()); + private static final boolean MULTI_THREADED_STAGES_TOKENIZER = false; + public static final int TOKENIZE_NUM_BLOCKS = ConfigurationManager.getNumberTokenizeBlocks(); + + private DocumentRepresentation[] internalRepresentation = null; + private final TokenizerBuilder tokenizerBuilder; + private final TokenizerApplier tokenizerApplier; + + protected Tokenizer(TokenizerBuilder tokenizerBuilder, TokenizerApplier tokenizerApplier) { + this.tokenizerBuilder = tokenizerBuilder; + this.tokenizerApplier = tokenizerApplier; + } + + public Types.ValueType[] getSchema() { + return tokenizerApplier.getOutSchema(); + } + + public int getMaxNumRows(int inRows) { + return tokenizerApplier.getMaxNumRows(inRows); + } + + public int getNumRowsEstimate(){ + // Estimate upperbound because e.g. Count Applier has less since it only outputs each unique token once + if(internalRepresentation != null){ + if(tokenizerApplier.isWideFormat()) { + return internalRepresentation.length; + } + if(tokenizerApplier.hasPadding()) { + return internalRepresentation.length * tokenizerApplier.getMaxTokens(); + } + return Arrays.stream(internalRepresentation).mapToInt(doc -> Math.min(doc.tokens.size(), tokenizerApplier.getMaxTokens())).sum(); + } + throw new DMLRuntimeException("Internal Token Representation was not computed yet. 
Can not get exact size."); + } + + public long getNumCols() { + return tokenizerApplier.getNumCols(); + } + + public void allocateInternalRepresentation(int numDocuments){ + internalRepresentation = new DocumentRepresentation[numDocuments]; + tokenizerApplier.allocateInternalMeta(numDocuments); + } + + public FrameBlock tokenize(FrameBlock in) { + return tokenize(in, 1); + } + + public FrameBlock tokenize(FrameBlock in, int k) { + allocateInternalRepresentation(in.getNumRows()); + FrameBlock out = new FrameBlock(this.getSchema()); + if (k > 1 && !MULTI_THREADED_STAGES_TOKENIZER) { + DependencyThreadPool pool = new DependencyThreadPool(k); + LOG.debug("Tokenizing with full DAG on " + k + " Threads"); + try { + List<DependencyTask<?>> tokenizeTasks = getTokenizeTasks(in, out, pool); + int lastRow = pool.submitAllAndWait(tokenizeTasks).stream().map(s -> s == null? 0 :(Integer)s).max(Integer::compare).get(); + if(lastRow != out.getNumRows()){ + out = out.slice(0, lastRow - 1, 0, out.getNumColumns() - 1, null); + } + } catch (ExecutionException | InterruptedException e) { + LOG.error("MT tokenize failed"); + e.printStackTrace(); + } + pool.shutdown(); + } else { + build(in, k); + out.ensureAllocatedColumns(tokenizerApplier.getNumRows(this.internalRepresentation)); + out = apply(out, k); + } + return out; + } + + private List<DependencyTask<?>> getTokenizeTasks(FrameBlock in, FrameBlock out, DependencyThreadPool pool) { + // TODO further optimisation of task graph to reduce memory usage! + // TODO add cache awareness + List<DependencyTask<?>> tasks = new ArrayList<>(); + Map<Integer[], Integer[]> depMap = new HashMap<>(); + tasks.add(DependencyThreadPool.createDependencyTask(new AllocateOutputFrame(this, out))); + List<DependencyTask<?>> buildTasks = getBuildTasks(in); // First half is builder build second half is applier build, dependencies already done + tasks.addAll(buildTasks); + List<DependencyTask<?>> applyTasks = tokenizerApplier.getApplyTasks(this.internalRepresentation, out); + if(applyTasks.size() != buildTasks.size() / 2) + throw new DMLRuntimeException("Different block sizes between build and apply tasks currently not supported"); + // Builder creates internal representation for a given section + // Applier builder creates additional meta information which will be needed in the apply step + // If there is long representation and no padding: + // - Count and Hash apply tasks have dependencies to the metadata build task of all previous chunks due to "getOutputRow". + // e.g. apply task starting at row 100 with block size 50 has dependencies to the ApplierBuildTask responsible for sections [0-49] and [50-99]. + // - Same for Position only they are only dependent on the internal representation creation since it does not have metadata. 
+ if(!tokenizerApplier.isWideFormat() || !tokenizerApplier.hasPadding()){ + int buildTaskOffset; + if(tokenizerApplier instanceof TokenizerApplierPosition){ + buildTaskOffset = 0; + } + else if (tokenizerApplier instanceof TokenizerApplierCount || tokenizerApplier instanceof TokenizerApplierHash) { + buildTaskOffset = applyTasks.size(); + } + else{ + throw new DMLRuntimeException("Unknown TokenizerApplier"); + } + depMap.put(new Integer[] {0, 1}, new Integer[]{1, (buildTasks.size()/2) + 1}); + depMap.put(new Integer[] {tasks.size(), tasks.size()+applyTasks.size()}, new Integer[]{0, 1}); + for(int i = 0; i < applyTasks.size(); i++){ + depMap.put(new Integer[] {tasks.size() + i, tasks.size()+applyTasks.size()}, new Integer[]{1+buildTaskOffset + i, 2+buildTaskOffset + i}); + } + } + tasks.addAll(applyTasks); + List<List<? extends Callable<?>>> deps = new ArrayList<>(Collections.nCopies(tasks.size(), null)); + DependencyThreadPool.createDependencyList(tasks, depMap, deps); + return DependencyThreadPool.createDependencyTasks(tasks, deps); + } + + public FrameBlock apply(FrameBlock out, int k) { + int lastRow = -1; + if(k > 1){ + DependencyThreadPool pool = new DependencyThreadPool(k); + try{ + List<DependencyTask<?>> taskList = tokenizerApplier.getApplyTasks(this.internalRepresentation, out); + lastRow = pool.submitAllAndWait(taskList).stream().map(s -> (Integer)s).max(Integer::compare).get(); + } + catch(ExecutionException | InterruptedException e) { + LOG.error("MT Tokenizer apply failed"); + e.printStackTrace(); + } + pool.shutdown(); + + }else{ + lastRow = tokenizerApplier.applyInternalRepresentation(this.internalRepresentation, out); + } + if(lastRow != out.getNumRows()){ + out = out.slice(0, lastRow - 1, 0, out.getNumColumns() - 1, null); + } + + return out; + } + + public List<DependencyTask<?>> getBuildTasks(FrameBlock in){ + List<DependencyTask<?>> tasks = tokenizerBuilder.getTasks(in, this.internalRepresentation); + List<DependencyTask<?>> applierBuildTaskList = tokenizerApplier.getBuildTasks(this.internalRepresentation); + if(tasks.size() != applierBuildTaskList.size()) + throw new DMLRuntimeException("Cannot create dependencies for mismatched array sizes"); + tasks.addAll(applierBuildTaskList); + List<List<? 
extends Callable<?>>> deps = new ArrayList<>(Collections.nCopies(tasks.size(), null)); + Map<Integer[], Integer[]> depMap = new HashMap<>(); + for(int i = 0; i < tasks.size() / 2; i++){ + depMap.put(new Integer[]{i+applierBuildTaskList.size(), i+applierBuildTaskList.size() + 1}, new Integer[] {i, i+1}); + } + DependencyThreadPool.createDependencyList(tasks, depMap, deps); + tasks = DependencyThreadPool.createDependencyTasks(tasks, deps); + return tasks; + } + + public void build(FrameBlock in, int k){ + tokenizerApplier.allocateInternalMeta(in.getNumRows()); + if(k > 1){ + DependencyThreadPool pool = new DependencyThreadPool(k); + try{ + pool.submitAllAndWait(getBuildTasks(in)); + } + catch(ExecutionException | InterruptedException e) { + LOG.error("MT Tokenizer build failed"); + e.printStackTrace(); + } + pool.shutdown(); + + }else{ + tokenizerBuilder.createInternalRepresentation(in, this.internalRepresentation); + tokenizerApplier.build(this.internalRepresentation, 0, -1); + } + } + + + protected static class AllocateOutputFrame implements Callable<Object>{ + + protected final Tokenizer _tokenizer; + protected final FrameBlock _out; + + protected AllocateOutputFrame(Tokenizer tokenizer, + FrameBlock out){ + this._tokenizer = tokenizer; + this._out = out; + } + + @Override + public Object call() throws Exception { + _out.ensureAllocatedColumns(_tokenizer.getNumRowsEstimate()); + return null; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java index 218bb5ee4e..2965ba8ebf 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/TokenizerFactory.java @@ -30,92 +30,89 @@ import org.apache.sysds.runtime.transform.tokenize.builder.TokenizerBuilderWhite import org.apache.wink.json4j.JSONObject; import org.apache.wink.json4j.JSONArray; -import java.util.ArrayList; -import java.util.List; - public class TokenizerFactory { - public static Tokenizer createTokenizer(String spec, int maxTokens) { - Tokenizer tokenizer = null; - - try { - //parse transform specification - JSONObject jSpec = new JSONObject(spec); - - // tokenization needs an algorithm (with algorithm specific params) - String algo = jSpec.getString("algo"); - JSONObject algoParams = null; - if (jSpec.has("algo_params")) { - algoParams = jSpec.getJSONObject("algo_params"); - } - - // tokenization needs an output representation (with representation specific params) - String out = jSpec.getString("out"); - JSONObject outParams = null; - if (jSpec.has("out_params")) { - outParams = jSpec.getJSONObject("out_params"); - } - - // tokenization needs a text column to tokenize - int tokenizeCol = jSpec.getInt("tokenize_col"); - - // tokenization needs one or more idCols that define the document and are replicated per token - JSONArray idColsJsonArray = jSpec.getJSONArray("id_cols"); - int[] idCols = new int[idColsJsonArray.length()]; - for (int i=0; i < idColsJsonArray.length(); i++) { - idCols[i] = idColsJsonArray.getInt(i); - } - // Output schema is derived from specified id cols - int numIdCols = idCols.length; - - // get difference between long and wide format - boolean wideFormat = false; // long format is default - if (jSpec.has("format_wide")) { - wideFormat = jSpec.getBoolean("format_wide"); - } - - boolean applyPadding = false; // no padding is default - if (jSpec.has("apply_padding")) { - applyPadding = 
jSpec.getBoolean("apply_padding"); - } - - TokenizerBuilder tokenizerBuilder; - TokenizerApplier tokenizerApplier; - - // Note that internal representation should be independent of output representation - - // Algorithm to transform tokens into internal token representation - switch (algo) { - case "split": - tokenizerBuilder = new TokenizerBuilderWhitespaceSplit(idCols, tokenizeCol, algoParams); - break; - case "ngram": - tokenizerBuilder = new TokenizerBuilderNgram(idCols, tokenizeCol, algoParams); - break; - default: - throw new IllegalArgumentException("Algorithm {algo=" + algo + "} is not supported."); - } - - // Transform tokens to output representation - switch (out) { - case "count": - tokenizerApplier = new TokenizerApplierCount(numIdCols, maxTokens, wideFormat, applyPadding, outParams); - break; - case "position": - tokenizerApplier = new TokenizerApplierPosition(numIdCols, maxTokens, wideFormat, applyPadding); - break; - case "hash": - tokenizerApplier = new TokenizerApplierHash(numIdCols, maxTokens, wideFormat, applyPadding, outParams); - break; - default: - throw new IllegalArgumentException("Output representation {out=" + out + "} is not supported."); - } - - tokenizer = new Tokenizer(tokenizerBuilder, tokenizerApplier); - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); - } - return tokenizer; - } + public static Tokenizer createTokenizer(String spec, int maxTokens) { + Tokenizer tokenizer = null; + + try { + //parse transform specification + JSONObject jSpec = new JSONObject(spec); + + // tokenization needs an algorithm (with algorithm specific params) + String algo = jSpec.getString("algo"); + JSONObject algoParams = null; + if (jSpec.has("algo_params")) { + algoParams = jSpec.getJSONObject("algo_params"); + } + + // tokenization needs an output representation (with representation specific params) + String out = jSpec.getString("out"); + JSONObject outParams = null; + if (jSpec.has("out_params")) { + outParams = jSpec.getJSONObject("out_params"); + } + + // tokenization needs a text column to tokenize + int tokenizeCol = jSpec.getInt("tokenize_col"); + + // tokenization needs one or more idCols that define the document and are replicated per token + JSONArray idColsJsonArray = jSpec.getJSONArray("id_cols"); + int[] idCols = new int[idColsJsonArray.length()]; + for (int i=0; i < idColsJsonArray.length(); i++) { + idCols[i] = idColsJsonArray.getInt(i); + } + // Output schema is derived from specified id cols + int numIdCols = idCols.length; + + // get difference between long and wide format + boolean wideFormat = false; // long format is default + if (jSpec.has("format_wide")) { + wideFormat = jSpec.getBoolean("format_wide"); + } + + boolean applyPadding = false; // no padding is default + if (jSpec.has("apply_padding")) { + applyPadding = jSpec.getBoolean("apply_padding"); + } + + TokenizerBuilder tokenizerBuilder; + TokenizerApplier tokenizerApplier; + + // Note that internal representation should be independent of output representation + + // Algorithm to transform tokens into internal token representation + switch (algo) { + case "split": + tokenizerBuilder = new TokenizerBuilderWhitespaceSplit(idCols, tokenizeCol, algoParams); + break; + case "ngram": + tokenizerBuilder = new TokenizerBuilderNgram(idCols, tokenizeCol, algoParams); + break; + default: + throw new IllegalArgumentException("Algorithm {algo=" + algo + "} is not supported."); + } + + // Transform tokens to output representation + switch (out) { + case "count": + tokenizerApplier = new 
TokenizerApplierCount(numIdCols, maxTokens, wideFormat, applyPadding, outParams); + break; + case "position": + tokenizerApplier = new TokenizerApplierPosition(numIdCols, maxTokens, wideFormat, applyPadding); + break; + case "hash": + tokenizerApplier = new TokenizerApplierHash(numIdCols, maxTokens, wideFormat, applyPadding, outParams); + break; + default: + throw new IllegalArgumentException("Output representation {out=" + out + "} is not supported."); + } + + tokenizer = new Tokenizer(tokenizerBuilder, tokenizerApplier); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + return tokenizer; + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java index de37e51516..f68db88a2c 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplier.java @@ -38,176 +38,175 @@ import static org.apache.sysds.runtime.transform.tokenize.Tokenizer.TOKENIZE_NUM import static org.apache.sysds.runtime.util.UtilFunctions.getBlockSizes; public abstract class TokenizerApplier implements Serializable { - - protected static final Log LOG = LogFactory.getLog(TokenizerApplier.class.getName()); - - public static final String PADDING_STRING = ""; - - protected final int numIdCols; - protected final int maxTokens; - protected final boolean wideFormat; - protected final boolean applyPadding; - - public TokenizerApplier(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding){ - this.numIdCols = numIdCols; - this.maxTokens = maxTokens; - this.wideFormat = wideFormat; - this.applyPadding = applyPadding; - } - - public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out){ - return applyInternalRepresentation(internalRepresentation, out, 0, -1); - } - abstract int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int startRow, int blk); - - public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk){ } - - public List<DependencyTask<?>> getBuildTasks(DocumentRepresentation[] internalRepresentation){ - int nRows = internalRepresentation.length; - List<Callable<Object>> tasks = new ArrayList<>(); - int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS); - if(blockSizes.length == 1){ - tasks.add(new TokenizerApplierBuildTask<>(this, internalRepresentation, 0, -1)); - } - else { - for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){ - tasks.add(new TokenizerApplierBuildTask<>(this, internalRepresentation, startRow, blockSizes[i])); - } - } - return DependencyThreadPool.createDependencyTasks(tasks, null); - } - - public List<DependencyTask<?>> getApplyTasks(DocumentRepresentation[] internalRepresentation, FrameBlock out) { - int nRows = internalRepresentation.length; - List<Callable<Object>> tasks = new ArrayList<>(); - int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS); - if(blockSizes.length == 1){ - tasks.add(new TokenizerApplyTask<>(this, out, internalRepresentation, 0, -1)); - } - else { - for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){ - tasks.add(new TokenizerApplyTask<>(this, out, internalRepresentation, startRow, blockSizes[i])); - } - } - return DependencyThreadPool.createDependencyTasks(tasks, null); - } - - protected int 
setKeys(int row, List<Object> keys, FrameBlock out){ - int col = 0; - for(; col < keys.size(); col++){ - out.set(row, col, keys.get(col)); - } - return col; - } - - protected int applyPaddingLong(int startRow, int numTokens, List<Object> keys, FrameBlock out, Object val1, Object val2){ - int row = startRow; - for (; numTokens < maxTokens; numTokens++, row++){ - int col = setKeys(row, keys, out); - out.set(row, col, val1); - out.set(row, col+1, val2); - } - return row; - } - - protected void applyPaddingWide(int row, int offset, int startToken, FrameBlock out, Object padding){ - int token = startToken; - for (; token < maxTokens; token++) { - out.set(row, offset+token, padding); - } - } - - public abstract Types.ValueType[] getOutSchema(); - - public boolean hasPadding(){ - return applyPadding; - } - - public int getMaxTokens(){ - return maxTokens; - } - - public int getMaxNumRows(int inRows) { - if (wideFormat) { - return inRows; - } else { - return inRows * maxTokens; - } - } - public abstract int getNumRows(DocumentRepresentation[] internalRepresentation); - - public <T, E> int getOutputRow(int inputRowStart, List<Map<T, E>> internalData){ - if(wideFormat) - return inputRowStart; - if(applyPadding) - return maxTokens * inputRowStart; - return internalData.stream().limit(inputRowStart).mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum(); - } - - public int getOutputRow(int inputRowStart, DocumentRepresentation[] internalData){ - if(wideFormat) - return inputRowStart; - if(applyPadding) - return maxTokens * inputRowStart; - return Arrays.stream(internalData).limit(inputRowStart).mapToInt(doc -> Math.min(doc.tokens.size(), maxTokens)).sum(); - } - - public long getNumCols() { - return this.getOutSchema().length; - } - - public boolean isWideFormat() { - return wideFormat; - } - - public void allocateInternalMeta(int numDocuments) { } - - - protected static class TokenizerApplyTask<T extends TokenizerApplier> implements Callable<Object>{ - - protected final T _tokenizerApplier; - protected final FrameBlock _output; - protected final DocumentRepresentation[] _internalRepresentation; - protected final int _rowStart; - protected final int _blk; - - protected TokenizerApplyTask(T tokenizerApplier, FrameBlock out, - DocumentRepresentation[] internalRepresentation, - int rowStart, int blk){ - this._tokenizerApplier = tokenizerApplier; - this._output = out; - this._internalRepresentation = internalRepresentation; - this._rowStart = rowStart; - this._blk = blk; - } - - @Override - public Object call() throws Exception { - return this._tokenizerApplier.applyInternalRepresentation(this._internalRepresentation, this._output, this._rowStart, this._blk); - } - } - - protected static class TokenizerApplierBuildTask<T extends TokenizerApplier> implements Callable<Object>{ - - protected final T _tokenizerApplier; - protected final DocumentRepresentation[] _internalRepresentation; - protected final int _rowStart; - protected final int _blk; - - protected TokenizerApplierBuildTask(T tokenizerApplier, - DocumentRepresentation[] internalRepresentation, - int rowStart, int blk){ - this._tokenizerApplier = tokenizerApplier; - this._internalRepresentation = internalRepresentation; - this._rowStart = rowStart; - this._blk = blk; - } - - @Override - public Object call() throws Exception { - this._tokenizerApplier.build(this._internalRepresentation, this._rowStart, this._blk); - return null; - } - } - + private static final long serialVersionUID = 39116559705096787L; + + protected static final Log LOG = 
LogFactory.getLog(TokenizerApplier.class.getName()); + + public static final String PADDING_STRING = ""; + + protected final int numIdCols; + protected final int maxTokens; + protected final boolean wideFormat; + protected final boolean applyPadding; + + public TokenizerApplier(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding){ + this.numIdCols = numIdCols; + this.maxTokens = maxTokens; + this.wideFormat = wideFormat; + this.applyPadding = applyPadding; + } + + public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out){ + return applyInternalRepresentation(internalRepresentation, out, 0, -1); + } + abstract int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int startRow, int blk); + + public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk){ } + + public List<DependencyTask<?>> getBuildTasks(DocumentRepresentation[] internalRepresentation){ + int nRows = internalRepresentation.length; + List<Callable<Object>> tasks = new ArrayList<>(); + int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS); + if(blockSizes.length == 1){ + tasks.add(new TokenizerApplierBuildTask<>(this, internalRepresentation, 0, -1)); + } + else { + for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){ + tasks.add(new TokenizerApplierBuildTask<>(this, internalRepresentation, startRow, blockSizes[i])); + } + } + return DependencyThreadPool.createDependencyTasks(tasks, null); + } + + public List<DependencyTask<?>> getApplyTasks(DocumentRepresentation[] internalRepresentation, FrameBlock out) { + int nRows = internalRepresentation.length; + List<Callable<Object>> tasks = new ArrayList<>(); + int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS); + if(blockSizes.length == 1){ + tasks.add(new TokenizerApplyTask<>(this, out, internalRepresentation, 0, -1)); + } + else { + for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){ + tasks.add(new TokenizerApplyTask<>(this, out, internalRepresentation, startRow, blockSizes[i])); + } + } + return DependencyThreadPool.createDependencyTasks(tasks, null); + } + + protected int setKeys(int row, List<Object> keys, FrameBlock out){ + int col = 0; + for(; col < keys.size(); col++){ + out.set(row, col, keys.get(col)); + } + return col; + } + + protected int applyPaddingLong(int startRow, int numTokens, List<Object> keys, FrameBlock out, Object val1, Object val2){ + int row = startRow; + for (; numTokens < maxTokens; numTokens++, row++){ + int col = setKeys(row, keys, out); + out.set(row, col, val1); + out.set(row, col+1, val2); + } + return row; + } + + protected void applyPaddingWide(int row, int offset, int startToken, FrameBlock out, Object padding){ + int token = startToken; + for (; token < maxTokens; token++) { + out.set(row, offset+token, padding); + } + } + + public abstract Types.ValueType[] getOutSchema(); + + public boolean hasPadding(){ + return applyPadding; + } + + public int getMaxTokens(){ + return maxTokens; + } + + public int getMaxNumRows(int inRows) { + if (wideFormat) { + return inRows; + } else { + return inRows * maxTokens; + } + } + public abstract int getNumRows(DocumentRepresentation[] internalRepresentation); + + public <T, E> int getOutputRow(int inputRowStart, List<Map<T, E>> internalData){ + if(wideFormat) + return inputRowStart; + if(applyPadding) + return maxTokens * inputRowStart; + return 
internalData.stream().limit(inputRowStart).mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum(); + } + + public int getOutputRow(int inputRowStart, DocumentRepresentation[] internalData){ + if(wideFormat) + return inputRowStart; + if(applyPadding) + return maxTokens * inputRowStart; + return Arrays.stream(internalData).limit(inputRowStart).mapToInt(doc -> Math.min(doc.tokens.size(), maxTokens)).sum(); + } + + public long getNumCols() { + return this.getOutSchema().length; + } + + public boolean isWideFormat() { + return wideFormat; + } + + public void allocateInternalMeta(int numDocuments) { } + + + protected static class TokenizerApplyTask<T extends TokenizerApplier> implements Callable<Object>{ + + protected final T _tokenizerApplier; + protected final FrameBlock _output; + protected final DocumentRepresentation[] _internalRepresentation; + protected final int _rowStart; + protected final int _blk; + + protected TokenizerApplyTask(T tokenizerApplier, FrameBlock out, + DocumentRepresentation[] internalRepresentation, + int rowStart, int blk){ + this._tokenizerApplier = tokenizerApplier; + this._output = out; + this._internalRepresentation = internalRepresentation; + this._rowStart = rowStart; + this._blk = blk; + } + + @Override + public Object call() throws Exception { + return this._tokenizerApplier.applyInternalRepresentation(this._internalRepresentation, this._output, this._rowStart, this._blk); + } + } + + protected static class TokenizerApplierBuildTask<T extends TokenizerApplier> implements Callable<Object>{ + + protected final T _tokenizerApplier; + protected final DocumentRepresentation[] _internalRepresentation; + protected final int _rowStart; + protected final int _blk; + + protected TokenizerApplierBuildTask(T tokenizerApplier, + DocumentRepresentation[] internalRepresentation, int rowStart, int blk){ + _tokenizerApplier = tokenizerApplier; + _internalRepresentation = internalRepresentation; + _rowStart = rowStart; + _blk = blk; + } + + @Override + public Object call() throws Exception { + this._tokenizerApplier.build(this._internalRepresentation, this._rowStart, this._blk); + return null; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java index a67467f51e..5298fa4634 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierCount.java @@ -28,111 +28,103 @@ import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.TreeMap; import java.util.TreeSet; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; public class TokenizerApplierCount extends TokenizerApplier { - private static final long serialVersionUID = 6382000606237705019L; - public boolean sort_alpha = false; + private static final long serialVersionUID = 6382000606237705019L; + public boolean sort_alpha = false; - private List<Map<String, Integer>> counts; + private List<Map<String, Integer>> counts; - public 
TokenizerApplierCount(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding, JSONObject params) throws JSONException { - super(numIdCols, maxTokens, wideFormat, applyPadding); - if (params != null && params.has("sort_alpha")) { - this.sort_alpha = params.getBoolean("sort_alpha"); - } - } + public TokenizerApplierCount(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding, JSONObject params) throws JSONException { + super(numIdCols, maxTokens, wideFormat, applyPadding); + if (params != null && params.has("sort_alpha")) { + this.sort_alpha = params.getBoolean("sort_alpha"); + } + } - @Override - public int getNumRows(DocumentRepresentation[] internalRepresentation) { - if(wideFormat) - return internalRepresentation.length; - if(applyPadding) - return maxTokens * internalRepresentation.length; - return counts.stream().mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum(); - } + @Override + public int getNumRows(DocumentRepresentation[] internalRepresentation) { + if(wideFormat) + return internalRepresentation.length; + if(applyPadding) + return maxTokens * internalRepresentation.length; + return counts.stream().mapToInt(hashMap -> Math.min(hashMap.size(), maxTokens)).sum(); + } - @Override - public void allocateInternalMeta(int numDocuments) { - counts = new ArrayList<>(Collections.nCopies(numDocuments,null)); - } + @Override + public void allocateInternalMeta(int numDocuments) { + counts = new ArrayList<>(Collections.nCopies(numDocuments,null)); + } - @Override - public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk){ - int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk); - for(int i = inputRowStart; i < endIndex; i++){ - Map<String, Integer> tokenCounts = new HashMap<>(); - for(Token token: internalRepresentation[i].tokens){ - String txt = token.toString(); - Integer count = tokenCounts.getOrDefault(txt, null); - if(count != null) - tokenCounts.put(txt, count + 1); - else - tokenCounts.put(txt, 1); - } - counts.set(i, tokenCounts); - } - } + @Override + public void build(DocumentRepresentation[] internalRepresentation, int inputRowStart, int blk){ + int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk); + for(int i = inputRowStart; i < endIndex; i++){ + Map<String, Integer> tokenCounts = new HashMap<>(); + for(Token token: internalRepresentation[i].tokens){ + String txt = token.toString(); + Integer count = tokenCounts.getOrDefault(txt, null); + if(count != null) + tokenCounts.put(txt, count + 1); + else + tokenCounts.put(txt, 1); + } + counts.set(i, tokenCounts); + } + } - @Override - public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) { - int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk); - int outputRow = getOutputRow(inputRowStart, counts); - for(int i = inputRowStart; i < endIndex; i++) { - List<Object> keys = internalRepresentation[i].keys; - // Creating the counts for BoW - Map<String, Integer> tokenCounts = counts.get(i); - // Remove duplicate strings - Collection<String> distinctTokens = tokenCounts.keySet(); - if (this.sort_alpha) { - // Sort alphabetically - distinctTokens = new TreeSet<>(distinctTokens); - } + @Override + public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) { + int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk); + int 
outputRow = getOutputRow(inputRowStart, counts); + for(int i = inputRowStart; i < endIndex; i++) { + List<Object> keys = internalRepresentation[i].keys; + // Creating the counts for BoW + Map<String, Integer> tokenCounts = counts.get(i); + // Remove duplicate strings + Collection<String> distinctTokens = tokenCounts.keySet(); + if (this.sort_alpha) { + // Sort alphabetically + distinctTokens = new TreeSet<>(distinctTokens); + } - int numTokens = 0; - for (String token: distinctTokens) { - if (numTokens >= maxTokens) { - break; - } - int col = setKeys(outputRow, keys, out); - // Create a row per token - long count = tokenCounts.get(token); - out.set(outputRow, col, token); - out.set(outputRow, col+1, count); - outputRow++; - numTokens++; - } - if(applyPadding){ - outputRow = applyPaddingLong(outputRow, numTokens, keys, out, PADDING_STRING, -1); - } - } - return outputRow; - } - - @Override - public Types.ValueType[] getOutSchema() { - if (wideFormat) { - throw new IllegalArgumentException("Wide Format is not supported for Count Representation."); - } - // Long format only depends on numIdCols - Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2,Types.ValueType.STRING ); - schema[numIdCols + 1] = Types.ValueType.INT64; - return schema; - } + int numTokens = 0; + for (String token: distinctTokens) { + if (numTokens >= maxTokens) { + break; + } + int col = setKeys(outputRow, keys, out); + // Create a row per token + long count = tokenCounts.get(token); + out.set(outputRow, col, token); + out.set(outputRow, col+1, count); + outputRow++; + numTokens++; + } + if(applyPadding){ + outputRow = applyPaddingLong(outputRow, numTokens, keys, out, PADDING_STRING, -1); + } + } + return outputRow; + } + @Override + public Types.ValueType[] getOutSchema() { + if (wideFormat) { + throw new IllegalArgumentException("Wide Format is not supported for Count Representation."); + } + // Long format only depends on numIdCols + Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2,Types.ValueType.STRING ); + schema[numIdCols + 1] = Types.ValueType.INT64; + return schema; + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java index e9e125ca73..2bbbed6827 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierHash.java @@ -19,19 +19,14 @@ package org.apache.sysds.runtime.transform.tokenize.applier; -import org.apache.commons.lang.ArrayUtils; import org.apache.sysds.common.Types; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation; -import org.apache.sysds.runtime.transform.tokenize.Token; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; -import scala.Array; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; @@ -168,5 +163,4 @@ public class TokenizerApplierHash extends TokenizerApplier { schema[numIdCols+1] = Types.ValueType.INT64; return schema; } - } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java 
b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java index c92e86b28d..070974c54e 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/applier/TokenizerApplierPosition.java @@ -43,8 +43,6 @@ public class TokenizerApplierPosition extends TokenizerApplier { return wideFormat ? internalRepresentation.length : Arrays.stream(internalRepresentation).mapToInt(doc -> applyPadding? maxTokens: Math.min(doc.tokens.size(), maxTokens)).sum(); } - - @Override public int applyInternalRepresentation(DocumentRepresentation[] internalRepresentation, FrameBlock out, int inputRowStart, int blk) { int endIndex = getEndIndex(internalRepresentation.length, inputRowStart, blk); diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java index 36c0c26e25..94cb0ad614 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilder.java @@ -21,7 +21,6 @@ package org.apache.sysds.runtime.transform.tokenize.builder; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation; -import org.apache.sysds.runtime.transform.tokenize.Tokenizer; import org.apache.sysds.runtime.util.DependencyTask; import org.apache.sysds.runtime.util.DependencyThreadPool; @@ -34,55 +33,52 @@ import static org.apache.sysds.runtime.transform.tokenize.Tokenizer.TOKENIZE_NUM import static org.apache.sysds.runtime.util.UtilFunctions.getBlockSizes; public abstract class TokenizerBuilder implements Serializable { - - - public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation) { - createInternalRepresentation(in, internalRepresentation, 0, -1); - } - - public abstract void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk); - - public List<DependencyTask<?>> getTasks(FrameBlock in, DocumentRepresentation[] internalRepresentation) { - int nRows = in.getNumRows(); - List<Callable<Object>> tasks = new ArrayList<>(); - int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS); - if(blockSizes.length == 1){ - tasks.add(new TokenizerBuildTask<>(this, in, internalRepresentation, 0, -1)); - } - else { - for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){ - tasks.add(new TokenizerBuildTask<>(this, in, internalRepresentation, startRow, blockSizes[i])); - } - } - return DependencyThreadPool.createDependencyTasks(tasks, null); - } - - - protected static class TokenizerBuildTask<T extends TokenizerBuilder> implements Callable<Object>{ - - protected final T _tokenizerBuilder; - protected final FrameBlock _input; - protected final DocumentRepresentation[] _internalRepresentation; - protected final int _rowStart; - protected final int _blk; - - protected TokenizerBuildTask(T tokenizerBuilder, FrameBlock input, - DocumentRepresentation[] internalRepresentation, - int rowStart, int blk){ - this._tokenizerBuilder = tokenizerBuilder; - this._input = input; - this._internalRepresentation = internalRepresentation; - this._rowStart = rowStart; - this._blk = blk; - } - - @Override - public Object call() throws Exception { - 
this._tokenizerBuilder.createInternalRepresentation(this._input, this._internalRepresentation, this._rowStart, this._blk); - return null; - } - } - - - + private static final long serialVersionUID = -4999630313246644464L; + + public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation) { + createInternalRepresentation(in, internalRepresentation, 0, -1); + } + + public abstract void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk); + + public List<DependencyTask<?>> getTasks(FrameBlock in, DocumentRepresentation[] internalRepresentation) { + int nRows = in.getNumRows(); + List<Callable<Object>> tasks = new ArrayList<>(); + int[] blockSizes = getBlockSizes(nRows, TOKENIZE_NUM_BLOCKS); + if(blockSizes.length == 1){ + tasks.add(new TokenizerBuildTask<>(this, in, internalRepresentation, 0, -1)); + } + else { + for(int startRow = 0, i = 0; i < blockSizes.length; startRow+=blockSizes[i], i++){ + tasks.add(new TokenizerBuildTask<>(this, in, internalRepresentation, startRow, blockSizes[i])); + } + } + return DependencyThreadPool.createDependencyTasks(tasks, null); + } + + + protected static class TokenizerBuildTask<T extends TokenizerBuilder> implements Callable<Object>{ + + protected final T _tokenizerBuilder; + protected final FrameBlock _input; + protected final DocumentRepresentation[] _internalRepresentation; + protected final int _rowStart; + protected final int _blk; + + protected TokenizerBuildTask(T tokenizerBuilder, FrameBlock input, + DocumentRepresentation[] internalRepresentation, + int rowStart, int blk){ + this._tokenizerBuilder = tokenizerBuilder; + this._input = input; + this._internalRepresentation = internalRepresentation; + this._rowStart = rowStart; + this._blk = blk; + } + + @Override + public Object call() throws Exception { + this._tokenizerBuilder.createInternalRepresentation(this._input, this._internalRepresentation, this._rowStart, this._blk); + return null; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java index 5ea87288b2..be29dc9adc 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderNgram.java @@ -33,69 +33,69 @@ import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; public class TokenizerBuilderNgram extends TokenizerBuilderWhitespaceSplit { - private static final long serialVersionUID = -6297904316677723802L; + private static final long serialVersionUID = -6297904316677723802L; - private enum NgramType{ - DOCUMENT, - TOKEN - } + private enum NgramType{ + DOCUMENT, + TOKEN + } - public int minGram = 1; - public int maxGram = 2; - public NgramType ngramType = NgramType.DOCUMENT; + public int minGram = 1; + public int maxGram = 2; + public NgramType ngramType = NgramType.DOCUMENT; - public TokenizerBuilderNgram(int[] idCols, int tokenizeCol, JSONObject params) throws JSONException { - super(idCols, tokenizeCol, params); - if (params != null && params.has("min_gram")) { - this.minGram = params.getInt("min_gram"); - } - if (params != null && params.has("max_gram")) { - this.maxGram = params.getInt("max_gram"); - } - if (params != null && params.has("ngram_type")){ - String type = params.getString("ngram_type").toLowerCase(); - 
if(type.equals("document")){ - this.ngramType = NgramType.DOCUMENT; - } else if (type.equals("token")) { - this.ngramType = NgramType.TOKEN; - }else { - throw new DMLRuntimeException("Invalid ngram type, choose between 'token' and 'document'"); - } - } - } + public TokenizerBuilderNgram(int[] idCols, int tokenizeCol, JSONObject params) throws JSONException { + super(idCols, tokenizeCol, params); + if (params != null && params.has("min_gram")) { + this.minGram = params.getInt("min_gram"); + } + if (params != null && params.has("max_gram")) { + this.maxGram = params.getInt("max_gram"); + } + if (params != null && params.has("ngram_type")){ + String type = params.getString("ngram_type").toLowerCase(); + if(type.equals("document")){ + this.ngramType = NgramType.DOCUMENT; + } else if (type.equals("token")) { + this.ngramType = NgramType.TOKEN; + }else { + throw new DMLRuntimeException("Invalid ngram type, choose between 'token' and 'document'"); + } + } + } - public List<Token> splitIntoNgrams(Token token, int minGram, int maxGram){ - if(token.getNumSubTokens() == 0) - throw new DMLRuntimeException("Cannot create ngram of token where there are no subTokens"); - if(token.getNumSubTokens() != 1) - throw new DMLRuntimeException("Cannot create ngram of token where there are more than 1 subTokens"); - String tokenText = token.toString(); - List<Token> newTokens = new ArrayList<>(); - for(int n = minGram; n <= maxGram; n++){ - for(int i = 0; i < tokenText.length() - n + 1; i++){ - String substring = tokenText.substring(i, i+n); - newTokens.add(new Token(substring, token.getStartIndex(0) + i)); - } - } - return newTokens; - } - @Override - public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk) { - super.createInternalRepresentation(in, internalRepresentation, rowStart, blk); - int endIndex = getEndIndex(in.getNumRows(), rowStart, blk); - for(int row = rowStart; row < endIndex; row++){ - DocumentRepresentation documentRepresentation = internalRepresentation[row]; - - if(this.ngramType == NgramType.DOCUMENT){ - documentRepresentation.splitIntoNgrams(this.minGram, this.maxGram); - } else if (this.ngramType == NgramType.TOKEN) { - List<Token> newTokens = new ArrayList<>(); - for (Token wordToken: documentRepresentation.getTokens()) { - newTokens.addAll(splitIntoNgrams(wordToken, this.minGram, this.maxGram)); - } - documentRepresentation.tokens = newTokens; - } - } - } + public List<Token> splitIntoNgrams(Token token, int minGram, int maxGram){ + if(token.getNumSubTokens() == 0) + throw new DMLRuntimeException("Cannot create ngram of token where there are no subTokens"); + if(token.getNumSubTokens() != 1) + throw new DMLRuntimeException("Cannot create ngram of token where there are more than 1 subTokens"); + String tokenText = token.toString(); + List<Token> newTokens = new ArrayList<>(); + for(int n = minGram; n <= maxGram; n++){ + for(int i = 0; i < tokenText.length() - n + 1; i++){ + String substring = tokenText.substring(i, i+n); + newTokens.add(new Token(substring, token.getStartIndex(0) + i)); + } + } + return newTokens; + } + + @Override + public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk) { + super.createInternalRepresentation(in, internalRepresentation, rowStart, blk); + int endIndex = getEndIndex(in.getNumRows(), rowStart, blk); + for(int row = rowStart; row < endIndex; row++){ + DocumentRepresentation documentRepresentation = 
internalRepresentation[row]; + if(this.ngramType == NgramType.DOCUMENT){ + documentRepresentation.splitIntoNgrams(this.minGram, this.maxGram); + } else if (this.ngramType == NgramType.TOKEN) { + List<Token> newTokens = new ArrayList<>(); + for (Token wordToken: documentRepresentation.getTokens()) { + newTokens.addAll(splitIntoNgrams(wordToken, this.minGram, this.maxGram)); + } + documentRepresentation.tokens = newTokens; + } + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java index c1ba7916b6..ac85c4256f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java +++ b/src/main/java/org/apache/sysds/runtime/transform/tokenize/builder/TokenizerBuilderWhitespaceSplit.java @@ -22,69 +22,61 @@ package org.apache.sysds.runtime.transform.tokenize.builder; import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation; import org.apache.sysds.runtime.transform.tokenize.Token; -import org.apache.sysds.runtime.transform.tokenize.Tokenizer; -import org.apache.sysds.runtime.util.DependencyTask; import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; -import java.io.Serializable; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Iterator; import java.util.List; import java.util.Objects; -import java.util.concurrent.Callable; -import java.util.stream.Collectors; -import static org.apache.sysds.runtime.util.UtilFunctions.getBlockSizes; import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; public class TokenizerBuilderWhitespaceSplit extends TokenizerBuilder { - private static final long serialVersionUID = 539127244034913364L; + private static final long serialVersionUID = 539127244034913364L; - private final int[] idCols; - private final int tokenizeCol; + private final int[] idCols; + private final int tokenizeCol; - public String regex = "\\s+"; // whitespace + public String regex = "\\s+"; // whitespace - public TokenizerBuilderWhitespaceSplit(int[] idCols, int tokenizeCol, JSONObject params) throws JSONException { - if (params != null && params.has("regex")) { - this.regex = params.getString("regex"); - } - this.idCols = idCols; - this.tokenizeCol = tokenizeCol; - } + public TokenizerBuilderWhitespaceSplit(int[] idCols, int tokenizeCol, JSONObject params) throws JSONException { + if (params != null && params.has("regex")) { + this.regex = params.getString("regex"); + } + this.idCols = idCols; + this.tokenizeCol = tokenizeCol; + } - public List<Token> splitToTokens(String text) { - List<Token> tokenList = new ArrayList<>(); - if(text == null) - return tokenList; - String[] textTokens = text.split(this.regex); - int curIndex = 0; - for(String textToken: textTokens) { - if(Objects.equals(textToken, "")){ - continue; - } - int tokenIndex = text.indexOf(textToken, curIndex); - curIndex = tokenIndex; - tokenList.add(new Token(textToken, tokenIndex)); - } - return tokenList; - } + public List<Token> splitToTokens(String text) { + List<Token> tokenList = new ArrayList<>(); + if(text == null) + return tokenList; + String[] textTokens = text.split(this.regex); + int curIndex = 0; + for(String textToken: textTokens) { + if(Objects.equals(textToken, "")){ + continue; + } + int tokenIndex = text.indexOf(textToken, curIndex); + curIndex = 
tokenIndex; + tokenList.add(new Token(textToken, tokenIndex)); + } + return tokenList; + } - @Override - public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk) { - int endIndex = getEndIndex(in.getNumRows(), rowStart, blk); - for (int i = rowStart; i < endIndex; i++) { - String text = in.getString(i, tokenizeCol - 1); - List<Token> tokenList = splitToTokens(text); - List<Object> keys = new ArrayList<>(); - for (Integer idCol : idCols) { - Object key = in.get(i, idCol - 1); - keys.add(key); - internalRepresentation[i] = new DocumentRepresentation(keys, tokenList); - } - } - } + @Override + public void createInternalRepresentation(FrameBlock in, DocumentRepresentation[] internalRepresentation, int rowStart, int blk) { + int endIndex = getEndIndex(in.getNumRows(), rowStart, blk); + for (int i = rowStart; i < endIndex; i++) { + String text = in.getString(i, tokenizeCol - 1); + List<Token> tokenList = splitToTokens(text); + List<Object> keys = new ArrayList<>(); + for (Integer idCol : idCols) { + Object key = in.get(i, idCol - 1); + keys.add(key); + internalRepresentation[i] = new DocumentRepresentation(keys, tokenList); + } + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTomeklinkTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTomeklinkTest.java index 457c71b6a8..c33017a471 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTomeklinkTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTomeklinkTest.java @@ -20,7 +20,6 @@ package org.apache.sysds.test.functions.builtin.part2; import org.junit.Ignore; -import org.junit.Test; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java index 2b28848da8..c1867df437 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TokenizeMultithreadedTest.java @@ -19,7 +19,6 @@ package org.apache.sysds.test.functions.transform; -import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; @@ -37,156 +36,155 @@ import org.junit.Test; import javax.json.Json; import javax.json.JsonObject; import javax.json.JsonObjectBuilder; -import java.io.IOException; public class TokenizeMultithreadedTest extends AutomatedTestBase { - private static final String TEST_DIR = "functions/transform/"; - private static final String TEST_CLASS_DIR = TEST_DIR + TokenizeMultithreadedTest.class.getSimpleName() + "/"; - - //dataset and transform tasks without missing values - private final static String DATASET = "20news/20news_subset_untokenized.csv"; - - - private final static JsonObject ngram_algo_params0 = Json.createObjectBuilder() - .add("min_gram", 2) - .add("max_gram", 3) - .add("regex", "\\W+") - .build(); - - private final static JsonObject count_out_params0 = Json.createObjectBuilder().add("sort_alpha", false).build(); - private final static JsonObject count_out_params1 = Json.createObjectBuilder().add("sort_alpha", true).build(); - - private final static 
JsonObject hash_out_params0 = Json.createObjectBuilder().add("num_features", 128).build(); - - public enum TokenizerBuilder { - WHITESPACE_SPLIT, - NGRAM, - } - - public enum TokenizerApplier { - COUNT, - HASH, - POSITION, - } - - @Override - public void setUp() { - TestUtils.clearAssertionInformation(); - addTestConfiguration(this.getClass().getSimpleName(), - new TestConfiguration(TEST_CLASS_DIR, this.getClass().getSimpleName(), new String[] { "R" }) ); - } - - @Test - public void testTokenizeSplitCountLong() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT,TokenizerApplier.COUNT, - 2000, false, null, count_out_params0); - } - - @Test - public void testTokenizeNgramCountLong() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.COUNT, - 2000, false, ngram_algo_params0, count_out_params0); - } - - @Test - public void testTokenizeSplitPositionLong() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.POSITION, - 2000, false, null, null); - } - - @Test - public void testTokenizeNgramPositionLong() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.POSITION, - 2000, false, ngram_algo_params0, null); - } - - @Test - public void testTokenizeSplitHashLong() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.HASH, - 2000, false, null, hash_out_params0); - } - - @Test - public void testTokenizeNgramHashLong() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.HASH, - 2000, false, ngram_algo_params0, hash_out_params0); - } - @Test - public void testTokenizeSplitCountWide() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT,TokenizerApplier.POSITION, - 2000, true, null, count_out_params0); - } - - @Test - public void testTokenizeNgramCountWide() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.POSITION, - 2000, true, ngram_algo_params0, count_out_params0); - } - - @Test - public void testTokenizeSplitHashWide() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.HASH, - 2000, true, null, hash_out_params0); - } - - @Test - public void testTokenizeNgramHashWide() { - runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.HASH, - 2000, true, ngram_algo_params0, hash_out_params0); - } - - private void runTokenizeTest(ExecMode rt, TokenizerBuilder builder, TokenizerApplier applier, - int max_tokens, boolean format_wide, JsonObject algo_params, JsonObject out_params) { - try{ - getAndLoadTestConfiguration(this.getClass().getSimpleName()); - FileFormatPropertiesCSV props = new FileFormatPropertiesCSV(); - props.setHeader(false); - FrameBlock input = FrameReaderFactory.createFrameReader(Types.FileFormat.CSV, props) - .readFrameFromHDFS(DATASET_DIR+DATASET, -1L, -1L); - String spec = createTokenizerSpec(builder, applier, format_wide, algo_params, out_params); - Tokenizer tokenizer = TokenizerFactory.createTokenizer(spec, max_tokens); - FrameBlock outS = tokenizer.tokenize(input, 1); - FrameBlock outM = tokenizer.tokenize(input, 12); - Assert.assertEquals(outS.getNumRows(), outM.getNumRows()); - Assert.assertEquals(outS.getNumColumns(), outM.getNumColumns()); - TestUtils.compareFrames(DataConverter.convertToStringFrame(outS), - DataConverter.convertToStringFrame(outM), outS.getNumRows(), outS.getNumColumns()); - - } catch (Exception ex){ - throw new RuntimeException(ex); - } - - } - 
- private String createTokenizerSpec(TokenizerBuilder builder, TokenizerApplier applier, boolean format_wide, JsonObject algo_params, JsonObject out_params) { - JsonObjectBuilder spec = Json.createObjectBuilder(); - switch (builder){ - case WHITESPACE_SPLIT: - spec.add("algo", "split"); - break; - case NGRAM: - spec.add("algo", "ngram"); - break; - } - switch (applier){ - case COUNT: - spec.add("out", "count"); - break; - case POSITION: - spec.add("out", "position"); - break; - case HASH: - spec.add("out", "hash"); - break; - } - if(out_params != null) - spec.add("out_params", out_params); - if(algo_params != null) - spec.add("algo_params", algo_params); - spec.add("format_wide", format_wide); - spec.add("id_cols",Json.createArrayBuilder().add(2).add(3)); - spec.add("tokenize_col", 4); - return spec.build().toString(); - } + private static final String TEST_DIR = "functions/transform/"; + private static final String TEST_CLASS_DIR = TEST_DIR + TokenizeMultithreadedTest.class.getSimpleName() + "/"; + + //dataset and transform tasks without missing values + private final static String DATASET = "20news/20news_subset_untokenized.csv"; + + + private final static JsonObject ngram_algo_params0 = Json.createObjectBuilder() + .add("min_gram", 2) + .add("max_gram", 3) + .add("regex", "\\W+") + .build(); + + private final static JsonObject count_out_params0 = Json.createObjectBuilder().add("sort_alpha", false).build(); + //private final static JsonObject count_out_params1 = Json.createObjectBuilder().add("sort_alpha", true).build(); + + private final static JsonObject hash_out_params0 = Json.createObjectBuilder().add("num_features", 128).build(); + + public enum TokenizerBuilder { + WHITESPACE_SPLIT, + NGRAM, + } + + public enum TokenizerApplier { + COUNT, + HASH, + POSITION, + } + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(this.getClass().getSimpleName(), + new TestConfiguration(TEST_CLASS_DIR, this.getClass().getSimpleName(), new String[] { "R" }) ); + } + + @Test + public void testTokenizeSplitCountLong() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT,TokenizerApplier.COUNT, + 2000, false, null, count_out_params0); + } + + @Test + public void testTokenizeNgramCountLong() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.COUNT, + 2000, false, ngram_algo_params0, count_out_params0); + } + + @Test + public void testTokenizeSplitPositionLong() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.POSITION, + 2000, false, null, null); + } + + @Test + public void testTokenizeNgramPositionLong() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.POSITION, + 2000, false, ngram_algo_params0, null); + } + + @Test + public void testTokenizeSplitHashLong() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.HASH, + 2000, false, null, hash_out_params0); + } + + @Test + public void testTokenizeNgramHashLong() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.HASH, + 2000, false, ngram_algo_params0, hash_out_params0); + } + @Test + public void testTokenizeSplitCountWide() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT,TokenizerApplier.POSITION, + 2000, true, null, count_out_params0); + } + + @Test + public void testTokenizeNgramCountWide() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, 
TokenizerApplier.POSITION, + 2000, true, ngram_algo_params0, count_out_params0); + } + + @Test + public void testTokenizeSplitHashWide() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.WHITESPACE_SPLIT, TokenizerApplier.HASH, + 2000, true, null, hash_out_params0); + } + + @Test + public void testTokenizeNgramHashWide() { + runTokenizeTest(ExecMode.SINGLE_NODE, TokenizerBuilder.NGRAM, TokenizerApplier.HASH, + 2000, true, ngram_algo_params0, hash_out_params0); + } + + private void runTokenizeTest(ExecMode rt, TokenizerBuilder builder, TokenizerApplier applier, + int max_tokens, boolean format_wide, JsonObject algo_params, JsonObject out_params) { + try{ + getAndLoadTestConfiguration(this.getClass().getSimpleName()); + FileFormatPropertiesCSV props = new FileFormatPropertiesCSV(); + props.setHeader(false); + FrameBlock input = FrameReaderFactory.createFrameReader(Types.FileFormat.CSV, props) + .readFrameFromHDFS(DATASET_DIR+DATASET, -1L, -1L); + String spec = createTokenizerSpec(builder, applier, format_wide, algo_params, out_params); + Tokenizer tokenizer = TokenizerFactory.createTokenizer(spec, max_tokens); + FrameBlock outS = tokenizer.tokenize(input, 1); + FrameBlock outM = tokenizer.tokenize(input, 12); + Assert.assertEquals(outS.getNumRows(), outM.getNumRows()); + Assert.assertEquals(outS.getNumColumns(), outM.getNumColumns()); + TestUtils.compareFrames(DataConverter.convertToStringFrame(outS), + DataConverter.convertToStringFrame(outM), outS.getNumRows(), outS.getNumColumns()); + + } catch (Exception ex){ + throw new RuntimeException(ex); + } + + } + + private String createTokenizerSpec(TokenizerBuilder builder, TokenizerApplier applier, boolean format_wide, JsonObject algo_params, JsonObject out_params) { + JsonObjectBuilder spec = Json.createObjectBuilder(); + switch (builder){ + case WHITESPACE_SPLIT: + spec.add("algo", "split"); + break; + case NGRAM: + spec.add("algo", "ngram"); + break; + } + switch (applier){ + case COUNT: + spec.add("out", "count"); + break; + case POSITION: + spec.add("out", "position"); + break; + case HASH: + spec.add("out", "hash"); + break; + } + if(out_params != null) + spec.add("out_params", out_params); + if(algo_params != null) + spec.add("algo_params", algo_params); + spec.add("format_wide", format_wide); + spec.add("id_cols",Json.createArrayBuilder().add(2).add(3)); + spec.add("tokenize_col", 4); + return spec.build().toString(); + } }
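
Note on the reformatted test above: createTokenizerSpec assembles a plain JSON spec string that is then handed to TokenizerFactory.createTokenizer(spec, max_tokens). As a hedged illustration (values taken from the configuration used in testTokenizeNgramCountLong; the exact key order emitted by javax.json may differ), the resulting spec looks roughly like:

    {"algo":"ngram", "out":"count",
     "out_params":{"sort_alpha":false},
     "algo_params":{"min_gram":2, "max_gram":3, "regex":"\\W+"},
     "format_wide":false, "id_cols":[2,3], "tokenize_col":4}

The same spec is consumed once for the single-threaded run (tokenize(input, 1)) and once for the multi-threaded run (tokenize(input, 12)), and the two output frames are compared row by row and column by column.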

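For context on the TokenizerBuilderNgram changes: in token-level mode, splitIntoNgrams enumerates all character substrings of length min_gram through max_gram of a single-subtoken token, each positioned at the original start index plus its offset within the token. A minimal sketch with a hypothetical input (not taken from the test data):

    // hypothetical token "fox" starting at character index 10
    splitIntoNgrams(new Token("fox", 10), 2, 3)
      -> ("fo", 10), ("ox", 11), ("fox", 10)

In document-level mode (the default), the n-gram expansion is instead delegated to DocumentRepresentation.splitIntoNgrams over the whitespace-split tokens produced by the parent TokenizerBuilderWhitespaceSplit.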