This is an automated email from the ASF dual-hosted git repository. sergeykamov pushed a commit to branch 261tmp in repository https://gitbox.apache.org/repos/asf/incubator-nlpcraft.git
commit 46a636be94bd9de66eb66cc3cd510db8dc12a16b Author: Sergey Kamov <[email protected]> AuthorDate: Sun Mar 7 14:43:53 2021 +0300 WIP. --- .../apache/nlpcraft/common/nlp/NCNlpSentence.scala | 176 +++++++------ .../org/apache/nlpcraft/common/nlp/SyTest.java | 241 ++++++++++++++++++ .../nlpcraft/common/util/NCComboRecursiveTask.java | 230 +++++++++++++++++ .../nlpcraft/probe/mgrs/nlp/enrichers/SyTest.java | 278 +++++++++++++++++++++ 4 files changed, 850 insertions(+), 75 deletions(-) diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/NCNlpSentence.scala b/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/NCNlpSentence.scala index f530327..9d9cb98 100644 --- a/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/NCNlpSentence.scala +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/NCNlpSentence.scala @@ -20,11 +20,13 @@ package org.apache.nlpcraft.common.nlp import com.typesafe.scalalogging.LazyLogging import org.apache.nlpcraft.common.NCE import org.apache.nlpcraft.common.nlp.pos.NCPennTreebank +import org.apache.nlpcraft.common.util.NCComboRecursiveTask import org.apache.nlpcraft.model.NCModel -import java.io.{Serializable ⇒ JSerializable} +import java.io.{Serializable => JSerializable} import java.util -import java.util.{Collections, List ⇒ JList} +import java.util.concurrent.ForkJoinPool +import java.util.{Collections, Comparator, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.collection.{Map, Seq, Set, mutable} @@ -41,7 +43,6 @@ object NCNlpSentence extends LazyLogging { require(start <= end) private def in(i: Int): Boolean = i >= start && i <= end - def intersect(id: String, start: Int, end: Int): Boolean = id == this.id && (in(start) || in(end)) } @@ -72,7 +73,7 @@ object NCNlpSentence extends LazyLogging { noteLinks ++= (for ((name, idxs) ← names.asScala.zip(idxsSeq.asScala.map(_.asScala))) yield NoteLink(name, idxs.sorted) - ) + ) } if (n.contains("subjnotes")) add("subjnotes", "subjindexes") @@ -410,8 +411,7 @@ object NCNlpSentence extends LazyLogging { "stopWord" → stop, "bracketed" → false, "direct" → direct, - "dict" → (if (nsCopyToks.size == 1) nsCopyToks.head.getNlpNote.data[Boolean]("dict") - else false), + "dict" → (if (nsCopyToks.size == 1) nsCopyToks.head.getNlpNote.data[Boolean]("dict") else false), "english" → nsCopyToks.forall(_.getNlpNote.data[Boolean]("english")), "swear" → nsCopyToks.exists(_.getNlpNote.data[Boolean]("swear")) ) @@ -458,8 +458,7 @@ object NCNlpSentence extends LazyLogging { var fixed = idxs history.foreach { - case (idxOld, idxNew) ⇒ fixed = fixed.map(_.map(i ⇒ if (i == idxOld) idxNew - else i).distinct) + case (idxOld, idxNew) ⇒ fixed = fixed.map(_.map(i ⇒ if (i == idxOld) idxNew else i).distinct) } if (fixed.forall(_.size == 1)) @@ -522,9 +521,9 @@ object NCNlpSentence extends LazyLogging { val res = fixIndexesReferences("nlpcraft:relation", "indexes", "note", ns, history) && - fixIndexesReferences("nlpcraft:limit", "indexes", "note", ns, history) && - fixIndexesReferencesList("nlpcraft:sort", "subjindexes", "subjnotes", ns, history) && - fixIndexesReferencesList("nlpcraft:sort", "byindexes", "bynotes", ns, history) + fixIndexesReferences("nlpcraft:limit", "indexes", "note", ns, history) && + fixIndexesReferencesList("nlpcraft:sort", "subjindexes", "subjnotes", ns, history) && + fixIndexesReferencesList("nlpcraft:sort", "byindexes", "bynotes", ns, history) if (res) { // Validation (all indexes calculated well) @@ -590,8 +589,7 @@ object NCNlpSentence extends LazyLogging { if (lastPhase) dropAbstract(mdl, ns) - if (collapseSentence(ns, getNotNlpNotes(ns).map(_.noteType).distinct)) Some(ns) - else None + if (collapseSentence(ns, getNotNlpNotes(ns).map(_.noteType).distinct)) Some(ns) else None } // Always deletes `similar` notes. @@ -599,27 +597,27 @@ object NCNlpSentence extends LazyLogging { // We keep only one variant - with `best` direct and sparsity parameters, // other variants for these words are redundant. val redundant: Seq[NCNlpSentenceNote] = - thisSen.flatten.filter(!_.isNlp).distinct. - groupBy(_.getKey()). - map(p ⇒ p._2.sortBy(p ⇒ - ( - // System notes don't have such flags. - if (p.isUser) { - if (p.isDirect) - 0 + thisSen.flatten.filter(!_.isNlp).distinct. + groupBy(_.getKey()). + map(p ⇒ p._2.sortBy(p ⇒ + ( + // System notes don't have such flags. + if (p.isUser) { + if (p.isDirect) + 0 + else + 1 + } else - 1 - } - else - 0, - if (p.isUser) - p.sparsity - else - 0 - ) - )). - flatMap(_.drop(1)). - toSeq + 0, + if (p.isUser) + p.sparsity + else + 0 + ) + )). + flatMap(_.drop(1)). + toSeq redundant.foreach(thisSen.removeNote) @@ -642,62 +640,91 @@ object NCNlpSentence extends LazyLogging { val key = PartKey(note, thisSen) val delCombOthers = - delCombs.filter(_ != note).flatMap(n ⇒ if (getPartKeys(n).contains(key)) Some(n) - else None) + delCombs.filter(_ != note).flatMap(n ⇒ if (getPartKeys(n).contains(key)) Some(n) else None) - if (delCombOthers.exists(o ⇒ noteWordsIdxs == o.wordIndexes.toSet)) Some(note) - else None + if (delCombOthers.exists(o ⇒ noteWordsIdxs == o.wordIndexes.toSet)) Some(note) else None }) delCombs = delCombs.filter(p ⇒ !swallowed.contains(p)) addDeleted(thisSen, thisSen, swallowed) swallowed.foreach(thisSen.removeNote) - val toksByIdx: Seq[Set[NCNlpSentenceNote]] = + val toksByIdx: Seq[Seq[NCNlpSentenceNote]] = delCombs.flatMap(note ⇒ note.wordIndexes.map(_ → note)). groupBy { case (idx, _) ⇒ idx }. - map { case (_, seq) ⇒ seq.map { case (_, note) ⇒ note }.toSet }. + map { case (_, seq) ⇒ seq.map { case (_, note) ⇒ note } }. toSeq.sortBy(-_.size) - val minDelSize = if (toksByIdx.isEmpty) 1 else toksByIdx.map(_.size).max - 1 +// val toksByIdx1 = +// delCombs.flatMap(note ⇒ note.wordIndexes.map(_ → note)). +// groupBy { case (idx, _) ⇒ idx }. +// map { case (idx, seq) ⇒ idx → seq.map { case (_, note) ⇒ note } }. +// toSeq.sortBy(_._2.size) +// +// toksByIdx.foreach{ case (seq) ⇒ +// println(s"toksByIdx seq=${seq.map(i ⇒ s"${i.noteType} ${i.wordIndexes.mkString(",")}").mkString(" | ")}") +// } + +// toksByIdx1.sortBy(_._1).foreach{ case (i, seq) ⇒ +// println(s"toksByIdx1 ${i} seq=${seq.map(i ⇒ s"${i.noteType} ${i.wordIndexes.mkString(",")}").mkString(" | ")}") +// } + + val dict = mutable.HashMap.empty[String, NCNlpSentenceNote] + + var i = 'A' + + val converted: Seq[Seq[String]] = + toksByIdx.map(seq ⇒ { + seq.map( + n ⇒ { + val s = s"$i" + + i = (i.toInt + 1).toChar + + dict += s → n + + s + } + ) + }) + + //val minDelSize = if (toksByIdx.isEmpty) 1 else toksByIdx.map(_.size).max - 1 var sens = if (delCombs.nonEmpty) { - val deleted = mutable.ArrayBuffer.empty[Set[NCNlpSentenceNote]] + val p = new ForkJoinPool() + + val tmp = NCComboRecursiveTask.findCombinations( + converted.map(_.asJava).asJava, + new Comparator[String]() { + override def compare(n1: String, n2: String): Int = n1.compareTo(n2) + }, + p + ) + + p.shutdown() + + val seq1 = tmp.asScala.map(_.asScala.map(dict)) val sens = - (minDelSize to delCombs.size). - flatMap(i ⇒ - delCombs.combinations(i). - filter(delComb ⇒ - !toksByIdx.exists( - rec ⇒ - rec.size - delCombs.size <= 1 && - rec.count(note ⇒ !delComb.contains(note)) > 1 - ) - ) - ). - sortBy(_.size). - map(_.toSet). - flatMap(delComb ⇒ - // Already processed with less subset of same deleted tokens. - if (!deleted.exists(_.subsetOf(delComb))) { - val nsClone = thisSen.clone() - - // Saves deleted notes for sentence and their tokens. - addDeleted(thisSen, nsClone, delComb) - delComb.foreach(nsClone.removeNote) - - // Has overlapped notes for some tokens. - require(!nsClone.exists(_.count(!_.isNlp) > 1)) - - deleted += delComb - - collapse0(nsClone) - } - else - None - ) + seq1. + flatMap(p ⇒ { + val delComb: Seq[NCNlpSentenceNote] = p + + val nsClone = thisSen.clone() + + // Saves deleted notes for sentence and their tokens. + addDeleted(thisSen, nsClone, delComb) + delComb.foreach(nsClone.removeNote) + + // Has overlapped notes for some tokens. + require( + !nsClone.exists(_.count(!_.isNlp) > 1), + s"Invalid notes: ${nsClone.filter(_.count(!_.isNlp) > 1).mkString("|")}" + ) + + collapse0(nsClone) + }) // It removes sentences which have only one difference - 'direct' flag of their user tokens. // `Direct` sentences have higher priority. @@ -719,8 +746,7 @@ object NCNlpSentence extends LazyLogging { p.clone().filter(_._1 != "direct") ) - (Key(get(sysNotes), get(userNotes)), sen, nlpNotes.map(p ⇒ if (p.isDirect) 0 - else 1).sum) + (Key(get(sysNotes), get(userNotes)), sen, nlpNotes.map(p ⇒ if (p.isDirect) 0 else 1).sum) }). foreach { case (key, sen, directCnt) ⇒ m.get(key) match { diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/SyTest.java b/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/SyTest.java new file mode 100644 index 0000000..db5f657 --- /dev/null +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/common/nlp/SyTest.java @@ -0,0 +1,241 @@ +package org.apache.nlpcraft.common.nlp; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.RecursiveTask; +import java.util.stream.IntStream; + +import static java.util.stream.Collectors.toList; +import static java.util.Arrays.asList; +import static java.util.stream.Collectors.toUnmodifiableList; + +public class SyTest { + public static class ComboSearch extends RecursiveTask<List<Long>> { + private static final long THRESHOLD = (long)Math.pow(2, 20); + + private final long lo; + + private final long hi; + + private final long[] wordBits; + + private final int[] wordCounts; + + public ComboSearch( + long lo, + long hi, + long[] words, + int[] wordCounts + ) { + this.lo = lo; + this.hi = hi; + this.wordBits = words; + this.wordCounts = wordCounts; + } + + public static <T> List<List<T>> findCombos(List<List<T>> inp, ForkJoinPool pool) { + List<List<T>> uniqueInp = inp.stream() + .filter(row -> inp.stream().noneMatch(it -> it != row && it.containsAll(row))) + .map(i -> i.stream().distinct().sorted().collect(toList())) + .collect(toList()); + + // Build dictionary of unique words. + List<T> dict = uniqueInp.stream() + .flatMap(Collection::stream) + .distinct() + .sorted() + .collect(toList()); + + System.out.println("uniqueInp="+uniqueInp.size()); + System.out.println("dict="+dict.size()); + + if (dict.size() > Long.SIZE) { + // Note: Power set of 64 words results in 9223372036854775807 combinations. + throw new IllegalArgumentException("Can handle more than " + Long.SIZE + " unique words in the dictionary."); + } + + // Convert words to bitmasks (each bit corresponds to an index in the dictionary). + long[] wordBits = uniqueInp.stream() + .sorted(Comparator.comparingInt(List::size)) + .mapToLong(row -> wordsToBits(row, dict)) + .toArray(); + + // Cache words count per row. + int[] wordCounts = uniqueInp.stream() + .sorted(Comparator.comparingInt(List::size)) + .mapToInt(List::size) + .toArray(); + + // Prepare Fork/Join task to iterate over the power set of all combinations. + int lo = 1; + long hi = (long)Math.pow(2, dict.size()); + + ComboSearch task = new ComboSearch( + lo, + hi, + wordBits, + wordCounts + ); + + return pool.invoke(task).stream() + .map(bits -> bitsToWords(bits, dict)) + .collect(toList()); + } + + @Override + protected List<Long> compute() { + if (hi - lo <= THRESHOLD) { + return computeLocal(); + } else { + return forkJoin(); + } + } + + private List<Long> computeLocal() { + List<Long> result = new ArrayList<>(); + + for (long comboBits = lo; comboBits < hi; comboBits++) { + boolean match = true; + + // For each input row we check if subtracting the current combination of words + // from the input row would give us the expected result. + for (int j = 0; j < wordBits.length; j++) { + // Get bitmask of how many words can be subtracted from the row. + long commonBits = wordBits[j] & comboBits; + + int wordsToRemove = Long.bitCount(commonBits); + + // Check if there is more than 1 word remaining after subtraction. + if (wordCounts[j] - wordsToRemove > 1) { + // Skip this combination. + match = false; + + break; + } + } + + if (match && !includes(comboBits, result)) { + result.add(comboBits); + } + } + + return result; + } + + private List<Long> forkJoin() { + long mid = lo + hi >>> 1L; + + ComboSearch t1 = new ComboSearch(lo, mid, wordBits, wordCounts); + ComboSearch t2 = new ComboSearch(mid, hi, wordBits, wordCounts); + + t2.fork(); + + return merge(t1.compute(), t2.join()); + } + + private List<Long> merge(List<Long> l1, List<Long> l2) { + if (l1.isEmpty()) { + return l2; + } else if (l2.isEmpty()) { + return l1; + } + + int size1 = l1.size(); + int size2 = l2.size(); + + if (size1 == 1 && size2 > 1 || size2 == 1 && size1 > 1) { + // Minor optimization in case if one of the lists has only one element. + List<Long> list = size1 == 1 ? l2 : l1; + Long val = size1 == 1 ? l1.get(0) : l2.get(0); + + if (!includes(val, list)) { + list.add(val); + } + + return list; + } else { + List<Long> result = new ArrayList<>(size1 + size2); + + for (int i = 0, max = Math.max(size1, size2); i < max; i++) { + Long v1 = i < size1 ? l1.get(i) : null; + Long v2 = i < size2 ? l2.get(i) : null; + + if (v1 != null && v2 != null) { + if (containsAllBits(v1, v2)) { + v1 = null; + } else if (containsAllBits(v2, v1)) { + v2 = null; + } + } + + if (v1 != null && !includes(v1, result)) { + result.add(v1); + } + + if (v2 != null && !includes(v2, result)) { + result.add(v2); + } + } + + return result; + } + } + + private static boolean includes(long bits, List<Long> allBits) { + for (long existing : allBits) { + if (containsAllBits(bits, existing)) { + return true; + } + } + + return false; + } + + private static boolean containsAllBits(long bitSet1, long bitSet2) { + return (bitSet1 & bitSet2) == bitSet2; + } + + private static <T> long wordsToBits(List<T> words, List<T> dict) { + long bits = 0; + + for (int i = 0; i < dict.size(); i++) { + if (words.contains(dict.get(i))) { + bits |= 1L << i; + } + } + + return bits; + } + + private static <T> List<T> bitsToWords(long bits, List<T> dict) { + List<T> words = new ArrayList<>(Long.bitCount(bits)); + + for (int i = 0; i < dict.size(); i++) { + if ((bits & 1L << i) != 0) { + words.add(dict.get(i)); + } + } + + return words; + } + } + + public static void main(String[] args) throws InterruptedException { + List<List<String>> words = IntStream.range(0, 35) + .mapToObj(i -> IntStream.range(0, i + 1).mapToObj(String::valueOf).collect(toUnmodifiableList())) + .collect(toUnmodifiableList()); + + long t = System.currentTimeMillis(); + + ForkJoinPool forkJoinPool = new ForkJoinPool(); + final List<List<String>> combos = ComboSearch.findCombos(words, forkJoinPool); + + System.out.println("size=" + combos.size()); + System.out.println("time=" + (System.currentTimeMillis() - t)); + } +} \ No newline at end of file diff --git a/nlpcraft/src/main/scala/org/apache/nlpcraft/common/util/NCComboRecursiveTask.java b/nlpcraft/src/main/scala/org/apache/nlpcraft/common/util/NCComboRecursiveTask.java new file mode 100644 index 0000000..017c10e --- /dev/null +++ b/nlpcraft/src/main/scala/org/apache/nlpcraft/common/util/NCComboRecursiveTask.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nlpcraft.common.util; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.RecursiveTask; +import java.util.stream.Collectors; + +import static java.util.stream.Collectors.toList; + +public class NCComboRecursiveTask extends RecursiveTask<List<Long>> { + private static final long THRESHOLD = (long)Math.pow(2, 20); + + private final long lo; + private final long hi; + private final long[] wordBits; + private final int[] wordCounts; + + private NCComboRecursiveTask(long lo, long hi, long[] wordBits, int[] wordCounts) { + this.lo = lo; + this.hi = hi; + this.wordBits = wordBits; + this.wordCounts = wordCounts; + } + + public static <T> List<List<T>> findCombinations(List<List<T>> inp, Comparator<T> comparator, ForkJoinPool pool) { + List<List<T>> uniqueInp = inp.stream() + .filter(row -> inp.stream().noneMatch(it -> !it.equals(row) && it.containsAll(row))) + .map(i -> i.stream().distinct().sorted(comparator).collect(toList())) + .collect(toList()); + + + System.out.println("!!!"); + for (List<T> ts : uniqueInp) { + System.out.println("!!!ts="); + System.out.println(ts.stream().map(Object::toString).collect(Collectors.joining("\n"))); + } + System.out.println("!!!"); + + // Build dictionary of unique words. + List<T> dict = uniqueInp.stream() + .flatMap(Collection::stream) + .distinct() + .sorted(comparator) + .collect(toList()); + + System.out.println("dict="); + System.out.println(dict.stream().map(Object::toString).collect(Collectors.joining("\n"))); + System.out.println(); + + if (dict.size() > Long.SIZE) { + // Note: Power set of 64 words results in 9223372036854775807 combinations. + throw new IllegalArgumentException("Dictionary is too long: " + dict.size()); + } + + // Convert words to bitmasks (each bit corresponds to an index in the dictionary). + long[] wordBits = uniqueInp.stream() + .sorted(Comparator.comparingInt(List::size)) + .mapToLong(row -> wordsToBits(row, dict)) + .toArray(); + + // Cache words count per row. + int[] wordCounts = uniqueInp.stream().sorted(Comparator.comparingInt(List::size)).mapToInt(List::size).toArray(); + + // Prepare Fork/Join task to iterate over the power set of all combinations. + int lo = 1; + long hi = (long)Math.pow(2, dict.size()); + + NCComboRecursiveTask task = new NCComboRecursiveTask(lo, hi, wordBits, wordCounts); + + return pool.invoke(task).stream().map(bits -> bitsToWords(bits, dict)).collect(toList()); + } + + @Override + protected List<Long> compute() { + return hi - lo <= THRESHOLD ? computeLocal() : forkJoin(); + } + + private List<Long> computeLocal() { + List<Long> result = new ArrayList<>(); + + for (long comboBits = lo; comboBits < hi; comboBits++) { + boolean match = true; + + // For each input row we check if subtracting the current combination of words + // from the input row would give us the expected result. + for (int j = 0; j < wordBits.length; j++) { + // Get bitmask of how many words can be subtracted from the row. + long commonBits = wordBits[j] & comboBits; + + int wordsToRemove = Long.bitCount(commonBits); + + // Check if there is more than 1 word remaining after subtraction. + if (wordCounts[j] - wordsToRemove > 1) { + // Skip this combination. + match = false; + + break; + } + } + + if (match && !includes(comboBits, result)) { + result.add(comboBits); + } + } + + return result; + } + + private List<Long> forkJoin() { + long mid = lo + hi >>> 1L; + + NCComboRecursiveTask t1 = new NCComboRecursiveTask(lo, mid, wordBits, wordCounts); + NCComboRecursiveTask t2 = new NCComboRecursiveTask(mid, hi, wordBits, wordCounts); + + t2.fork(); + + return merge(t1.compute(), t2.join()); + } + + private List<Long> merge(List<Long> l1, List<Long> l2) { + if (l1.isEmpty()) { + return l2; + } + else if (l2.isEmpty()) { + return l1; + } + + int size1 = l1.size(); + int size2 = l2.size(); + + if (size1 == 1 && size2 > 1 || size2 == 1 && size1 > 1) { + // Minor optimization in case if one of the lists has only one element. + List<Long> list = size1 == 1 ? l2 : l1; + Long val = size1 == 1 ? l1.get(0) : l2.get(0); + + if (!includes(val, list)) { + list.add(val); + } + + return list; + } + else { + List<Long> result = new ArrayList<>(size1 + size2); + + for (int i = 0, max = Math.max(size1, size2); i < max; i++) { + Long v1 = i < size1 ? l1.get(i) : null; + Long v2 = i < size2 ? l2.get(i) : null; + + if (v1 != null && v2 != null) { + if (containsAllBits(v1, v2)) { + v1 = null; + } + else if (containsAllBits(v2, v1)) { + v2 = null; + } + } + + if (v1 != null && !includes(v1, result)) { + result.add(v1); + } + + if (v2 != null && !includes(v2, result)) { + result.add(v2); + } + } + + return result; + } + } + + private static boolean includes(long bits, List<Long> allBits) { + for (int i = 0, size = allBits.size(); i < size; i++) { + long existing = allBits.get(i); + + if (containsAllBits(bits, existing)) { + return true; + } + } + + return false; + } + + private static boolean containsAllBits(long bitSet1, long bitSet2) { + return (bitSet1 & bitSet2) == bitSet2; + } + + private static <T> long wordsToBits(List<T> words, List<T> dict) { + long bits = 0; + + for (int i = 0; i < dict.size(); i++) { + if (words.contains(dict.get(i))) { + bits |= 1L << i; + } + } + + return bits; + } + + private static <T> List<T> bitsToWords(long bits, List<T> dict) { + List<T> words = new ArrayList<>(Long.bitCount(bits)); + + for (int i = 0; i < dict.size(); i++) { + if ((bits & 1L << i) != 0) { + words.add(dict.get(i)); + } + } + + return words; + } +} diff --git a/nlpcraft/src/test/scala/org/apache/nlpcraft/probe/mgrs/nlp/enrichers/SyTest.java b/nlpcraft/src/test/scala/org/apache/nlpcraft/probe/mgrs/nlp/enrichers/SyTest.java new file mode 100644 index 0000000..69d9ab4 --- /dev/null +++ b/nlpcraft/src/test/scala/org/apache/nlpcraft/probe/mgrs/nlp/enrichers/SyTest.java @@ -0,0 +1,278 @@ +package org.apache.nlpcraft.probe.mgrs.nlp.enrichers; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.NavigableSet; +import java.util.Optional; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.stream.Collectors.toList; +import static java.util.stream.Collectors.toSet; + +public class SyTest { + public static void main(String[] args) { +// List<List<String>> words = asList( +// asList("A", "B", "C"), +// asList("B", "C", "D"), +// asList("B", "D") +// ); + List<List<String>> words = asList( + asList("A", "B"), + asList("C", "B"), + asList("D", "E"), + asList("D", "F"), + asList("G", "H"), + asList("I", "H"), + asList("J", "K"), + asList("L", "K"), + asList("M", "N"), + asList("M", "O"), + asList("P", "Q"), + asList("P", "R"), + asList("S", "T"), + asList("S", "U"), + asList("V", "W"), + asList("X", "W") + , + asList("Y", "Z"), + asList("A1", "A2"), + asList("A3", "A3"), + asList("A4", "A5", "A6") + ); + + System.out.println( + "Dictionary size:" + + words.stream() + .flatMap(Collection::stream) + .distinct() + .count() + ); + + System.out.println("===== Performance ====="); + + for (int i = 0; i < 1; i++) { + long t = System.currentTimeMillis(); + + Set<Set<String>> combos = findCombos(words); + + + + System.out.println("Iteration " + i + " Time: " + (System.currentTimeMillis() - t) + ", resCnt=" + combos.size()); + } + + if (true) { + return; + } + + Set<Set<String>> combos = findCombos(words); + + System.out.println(); + System.out.println("===== Result ====="); + System.out.println("Total combos: " + combos.size()); + System.out.println(); +// combos.stream() +// .sorted(Comparator.comparing(Collection::size)) +// .forEach(combo -> +// print(words, combo) +// ); + } + + public static <T extends Comparable<T>> Set<Set<T>> findCombos(List<List<T>> inp) { + + + List<List<T>> uniqueInp = inp.stream() + .filter(row -> inp.stream().noneMatch(it -> it != row && it.containsAll(row))) + .map(i -> i.stream().distinct().sorted().collect(toList())) + .collect(toList()); + + // Build dictionary of unique words. + List<T> dict = uniqueInp.stream() + .flatMap(Collection::stream) + .distinct() + .sorted() + .collect(toList()); + + if (dict.size() > Integer.SIZE) { + // Note: Power set of 32 words results in 4294967296 combinations. + throw new IllegalArgumentException("Can handle more than " + Integer.SIZE + " unique words in the dictionary."); + } + + // Convert words to bitmasks (each bit corresponds to an index in the dictionary). + int[] wordBits = uniqueInp.stream() + .sorted(Comparator.comparingInt(List::size)) + .mapToInt(row -> wordsToBits(row, dict)) + .toArray(); + + // Cache words count per row. + int[] wordCounts = uniqueInp.stream() + .sorted(Comparator.comparingInt(List::size)) + .mapToInt(List::size) + .toArray(); + + int min = 1; + int max = (int)Math.pow(2, dict.size()) - 1; + + int batchFactor = 100; + int threads = 13; + + ExecutorService pool = Executors.newFixedThreadPool(threads); + CountDownLatch cdl = new CountDownLatch(batchFactor); + + int divRes = max / batchFactor; + int divRest = max % batchFactor; + + int to = 0; + + List<Integer> result = new CopyOnWriteArrayList<>(); + + for (int k = 0; k < batchFactor; k++) { + to += divRes; + + if (k == divRes - 1) { + to += divRest; + } + + int toFinal = to; + int fromFinal = min + k * divRes; + + pool.execute( + () -> { + List<Integer> locRes = new ArrayList<>(); + + for (int comboBits = fromFinal; comboBits < toFinal; comboBits++) { + boolean match = true; + + // For each input row we check if subtracting the current combination of words + // from the input row would give us the expected result. + for (int j = 0; j < wordBits.length; j++) { + // Get bitmask of how many words can be subtracted from the row. + int commonBits = wordBits[j] & comboBits; + + int wordsToRemove = Integer.bitCount(commonBits); + + // Check if there are more than 1 word remaining after subtraction. + if (wordCounts[j] - wordsToRemove > 1) { + // Skip this combination. + match = false; + + break; + } + } + + if (match && !includes(comboBits, locRes)) { + locRes.add(comboBits); + } + } + + result.addAll(locRes); + + cdl.countDown(); + } + ); + } + +// Iterate over the power set. + + //pool.shutdown(); + try { + cdl.await(Long.MAX_VALUE, TimeUnit.MILLISECONDS); + //pool.awaitTermination(Long.MAX_VALUE, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + // Convert found results from bitmasks back to words. + TreeSet<Set<T>> treeSet = new TreeSet<>(Comparator.comparingInt(Set::size)); + + treeSet.addAll(result.stream().map(bits -> bitsToWords(bits, dict)).collect(toSet())); + + Set<Set<T>> normCombs = new HashSet<>(); + + for (Set<T> set : treeSet) { + boolean b = true; + + for (Set<T> added : normCombs) { + if (added.containsAll(set)) { + b = false; + + break; + } + } + + if (b) { + normCombs.add(set); + } + } + + return normCombs; + + } + + private static <T> Set<Set<T>> squeeze(Set<Set<T>> combs) { + Set<Set<T>> normCombs = new HashSet<>(); + + combs.stream().sorted(Comparator.comparingInt(Set::size)).forEach(comb -> { + // Skips already added shorter variants. + if (normCombs.stream().filter(comb::containsAll).findAny().isEmpty()) { + normCombs.add(comb); + } + }); + + return normCombs; + } + + + private static boolean includes(int bits, List<Integer> allBits) { + for (int existing : allBits) { + if ((bits & existing) == existing) { + return true; + } + } + + return false; + } + + private static <T> int wordsToBits(List<T> words, List<T> dict) { + int bits = 0; + + for (int i = 0; i < dict.size(); i++) { + if (words.contains(dict.get(i))) { + bits |= 1 << i; + } + } + + return bits; + } + + private static <T> Set<T> bitsToWords(int bits, List<T> dict) { + Set<T> words = new HashSet<>(Integer.bitCount(bits)); + + for (int i = 0; i < dict.size(); i++) { + if ((bits & 1 << i) != 0) { + words.add(dict.get(i)); + } + } + + return words; + } + + private static void print(List<List<String>> inp, List<String> combo) { + System.out.println("==== " + combo + "(" + combo.size() + ')'); + inp.stream().forEach(row -> { + Set<String> s = new TreeSet<>(row); + s.removeAll(combo); + System.out.println(s); + }); + } +}
