HIVEMALL-130: Support user dictionary in `tokenize_ja`
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/07a7d51b Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/07a7d51b Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/07a7d51b Branch: refs/heads/dev/v0.4.2 Commit: 07a7d51beeb4bd31a3c6202c0de68486a43e5caf Parents: e1df050 Author: Takuya Kitazawa <k.tak...@gmail.com> Authored: Wed Sep 20 15:15:17 2017 +0900 Committer: Takuya Kitazawa <tak...@apache.org> Committed: Fri Sep 22 15:49:02 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/utils/hadoop/HiveUtils.java | 10 + .../main/java/hivemall/utils/io/HttpUtils.java | 51 +++++ .../main/java/hivemall/utils/io/IOUtils.java | 28 +++ .../hivemall/utils/io/LimitedInputStream.java | 87 ++++++++ .../utils/io/LimitedInputStreamTest.java | 92 ++++++++ .../hivemall/nlp/tokenizer/KuromojiUDF.java | 163 +++++++++++--- .../hivemall/nlp/tokenizer/KuromojiUDFTest.java | 210 ++++++++++++++++--- 7 files changed, 583 insertions(+), 58 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index c21a1d9..ad0dac6 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -27,6 +27,7 @@ import static hivemall.HivemallConstants.INT_TYPE_NAME; import static hivemall.HivemallConstants.SMALLINT_TYPE_NAME; import static hivemall.HivemallConstants.STRING_TYPE_NAME; import static hivemall.HivemallConstants.TINYINT_TYPE_NAME; +import static hivemall.HivemallConstants.VOID_TYPE_NAME; import 
java.util.Arrays; import java.util.BitSet; @@ -170,6 +171,11 @@ public final class HiveUtils { return STRING_TYPE_NAME.equals(typeName); } + public static boolean isVoidOI(@Nonnull final ObjectInspector oi) { + String typeName = oi.getTypeName(); + return VOID_TYPE_NAME.equals(typeName); + } + public static boolean isIntOI(@Nonnull final ObjectInspector oi) { String typeName = oi.getTypeName(); return INT_TYPE_NAME.equals(typeName); @@ -275,6 +281,10 @@ public final class HiveUtils { } } + public static boolean isConstListOI(@Nonnull final ObjectInspector oi) { + return ObjectInspectorUtils.isConstantObjectInspector(oi) && isListOI(oi); + } + public static boolean isConstString(@Nonnull final ObjectInspector oi) { return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/core/src/main/java/hivemall/utils/io/HttpUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/io/HttpUtils.java b/core/src/main/java/hivemall/utils/io/HttpUtils.java new file mode 100644 index 0000000..6994cfe --- /dev/null +++ b/core/src/main/java/hivemall/utils/io/HttpUtils.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.io; + +import java.io.IOException; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLConnection; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class HttpUtils { + + private HttpUtils() {} + + @Nonnull + public static HttpURLConnection getHttpURLConnection(@Nonnull String urlStr) + throws IllegalArgumentException, IOException { + if (!urlStr.startsWith("http://") && !urlStr.startsWith("https://")) { + throw new IllegalArgumentException("Unexpected url: " + urlStr); + } + URL url = new URL(urlStr); + URLConnection conn = url.openConnection(); + return (HttpURLConnection) conn; + } + + @Nonnull + public static InputStream getLimitedInputStream(@Nonnull HttpURLConnection conn, + @Nonnegative long size) throws IOException { + InputStream is = conn.getInputStream(); + return new LimitedInputStream(is, size); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/core/src/main/java/hivemall/utils/io/IOUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/io/IOUtils.java b/core/src/main/java/hivemall/utils/io/IOUtils.java index 1802dfc..2aa398b 100644 --- a/core/src/main/java/hivemall/utils/io/IOUtils.java +++ b/core/src/main/java/hivemall/utils/io/IOUtils.java @@ -33,6 +33,8 @@ import java.io.InputStreamReader; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStream; +import java.io.PushbackInputStream; +import java.util.zip.GZIPInputStream; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -129,6 +131,32 @@ public final class IOUtils { return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4 << 0)); } + /** + * Look ahead InputStream and decompress it as GZIPInputStream if needed + * 
+ * @link https://stackoverflow.com/a/4818946 + */ + @Nonnull + public static InputStream decodeInputStream(@Nonnull final InputStream is) throws IOException { + final PushbackInputStream pb = new PushbackInputStream(is, 2); + + // look ahead + final byte[] signature = new byte[2]; + final int nread = pb.read(signature); + // If no byte is available because the stream is at the end of the file, the value -1 is returned; + // otherwise, at least one byte is read and stored into b. + if (nread > 0) {// may be -1 (EOF) or 1 or 2 + pb.unread(signature, 0, nread); // push back + } + + final int streamHeader = ((int) signature[0] & 0xff) | ((signature[1] << 8) & 0xff00); + if (streamHeader == GZIPInputStream.GZIP_MAGIC) { + return new GZIPInputStream(pb); + } else { + return pb; + } + } + public static void writeChar(final char v, final OutputStream out) throws IOException { out.write(0xff & (v >> 8)); out.write(0xff & v); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/core/src/main/java/hivemall/utils/io/LimitedInputStream.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/io/LimitedInputStream.java b/core/src/main/java/hivemall/utils/io/LimitedInputStream.java new file mode 100644 index 0000000..54b8482 --- /dev/null +++ b/core/src/main/java/hivemall/utils/io/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.io; + +import hivemall.utils.lang.Preconditions; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnegative; + +/** + * Input stream which is limited to a certain length. Implementation is based on LimitedInputStream + * in Apache Commons FileUpload. + * + * @link + * https://commons.apache.org/proper/commons-fileupload/apidocs/org/apache/commons/fileupload/util + * /LimitedInputStream.html + */ +public class LimitedInputStream extends FilterInputStream { + + protected final long max; + protected long pos = 0L; + + public LimitedInputStream(@CheckForNull final InputStream in, @Nonnegative final long maxSize) { + super(in); + Preconditions.checkNotNull(in, "Base input stream must not be null"); + this.max = maxSize; + } + + protected void raiseError() throws IOException { + throw new IOException("Exceeded maximum size of input stream: limit = " + max + + " bytes, but pos = " + pos); + } + + private void proceed(@Nonnegative final long bytes) throws IOException { + this.pos += bytes; + if (pos > max) { + raiseError(); + } + } + + @Override + public int read() throws IOException { + final int res = super.read(); + if (res != -1) { + proceed(1L); + } + return res; + } + + @Override + public int read(final byte[] b, final int off, final int len) throws IOException { + final int res = super.read(b, off, len); + if (res > 0) { + proceed(res); + } + return res; + } + + @Override + public long skip(final long n) 
throws IOException { + final long res = super.skip(n); + if (res > 0) { + proceed(res); + } + return res; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/core/src/test/java/hivemall/utils/io/LimitedInputStreamTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/utils/io/LimitedInputStreamTest.java b/core/src/test/java/hivemall/utils/io/LimitedInputStreamTest.java new file mode 100644 index 0000000..18d17bf --- /dev/null +++ b/core/src/test/java/hivemall/utils/io/LimitedInputStreamTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.io; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; + +import org.junit.Assert; +import org.junit.Test; + +public class LimitedInputStreamTest { + + @Test + public void testExactSize() throws IOException { + String expected = "abcdef"; + int len = expected.length(); + + InputStream is = new FastByteArrayInputStream(expected.getBytes()); + LimitedInputStream isLimited = new LimitedInputStream(is, len); + + Reader reader = new InputStreamReader(isLimited); + BufferedReader br = new BufferedReader(reader); + + char[] buf = new char[len]; + br.read(buf); + + Assert.assertTrue(expected.equals(new String(buf))); + + br.close(); + } + + @Test + public void testLooseSize() throws IOException { + String expected = "abcdef"; + int len = expected.length(); + + InputStream is = new FastByteArrayInputStream(expected.getBytes()); + LimitedInputStream isLimited = new LimitedInputStream(is, len + 100); // large enough + + Reader reader = new InputStreamReader(isLimited); + BufferedReader br = new BufferedReader(reader); + + char[] buf = new char[len]; + br.read(buf); + + Assert.assertTrue(expected.equals(new String(buf))); + + br.close(); + } + + @Test(expected = IOException.class) + public void testExceed() throws IOException { + String expected = "abcdef"; + int len = expected.length(); + + InputStream is = new FastByteArrayInputStream(expected.getBytes()); + LimitedInputStream isLimited = new LimitedInputStream(is, len - 1); // not enough + + Reader reader = new InputStreamReader(isLimited); + BufferedReader br = new BufferedReader(reader); + + char[] buf = new char[len]; + br.read(buf); + + br.close(); + } + + @Test(expected = NullPointerException.class) + public void testNullInputStream() throws NullPointerException, IOException { + new LimitedInputStream(null, 100).close(); + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java ---------------------------------------------------------------------- diff --git a/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java b/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java index 425a40f..93b3095 100644 --- a/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java +++ b/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java @@ -1,27 +1,33 @@ /* - * Hivemall: Hive scalable Machine Learning Library + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * Copyright (C) 2015 Makoto YUI - * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * http://www.apache.org/licenses/LICENSE-2.0 * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. */ package hivemall.nlp.tokenizer; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; +import hivemall.utils.io.HttpUtils; +import java.io.InputStream; +import java.io.InputStreamReader; import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.net.HttpURLConnection; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -30,6 +36,7 @@ import java.util.List; import java.util.Set; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; @@ -44,19 +51,24 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.ja.JapaneseAnalyzer; import org.apache.lucene.analysis.ja.JapaneseTokenizer; import org.apache.lucene.analysis.ja.JapaneseTokenizer.Mode; +import org.apache.lucene.analysis.ja.dict.UserDictionary; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.util.CharArraySet; @Description( name = "tokenize_ja", - value = "_FUNC_(String line [, const string mode = \"normal\", const list<string> stopWords, const list<string> stopTags])" + value = "_FUNC_(String line [, const string mode = \"normal\", const array<string> stopWords, const array<string> stopTags, const array<string> userDict (or string userDictURL)])" + " - returns tokenized strings in array<string>") @UDFType(deterministic = true, stateful = false) public final class KuromojiUDF extends GenericUDF { + private static final int CONNECT_TIMEOUT_MS = 10000; // 10 sec + private static final int READ_TIMEOUT_MS = 60000; // 60 sec + private static final long MAX_INPUT_STREAM_SIZE = 32L * 1024L * 1024L; // ~32MB private Mode _mode; - private String[] _stopWordsArray; - private Set<String> _stoptags; + private CharArraySet 
_stopWords; + private Set<String> _stopTags; + private UserDictionary _userDict; // workaround to avoid org.apache.hive.com.esotericsoftware.kryo.KryoException: java.util.ConcurrentModificationException private transient JapaneseAnalyzer _analyzer; @@ -64,15 +76,18 @@ public final class KuromojiUDF extends GenericUDF { @Override public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { final int arglen = arguments.length; - if (arglen < 1 || arglen > 4) { + if (arglen < 1 || arglen > 5) { throw new UDFArgumentException("Invalid number of arguments for `tokenize_ja`: " + arglen); } this._mode = (arglen >= 2) ? tokenizationMode(arguments[1]) : Mode.NORMAL; - this._stopWordsArray = (arglen >= 3) ? HiveUtils.getConstStringArray(arguments[2]) : null; - this._stoptags = (arglen >= 4) ? stopTags(arguments[3]) + this._stopWords = (arglen >= 3) ? stopWords(arguments[2]) + : JapaneseAnalyzer.getDefaultStopSet(); + this._stopTags = (arglen >= 4) ? stopTags(arguments[3]) : JapaneseAnalyzer.getDefaultStopTags(); + this._userDict = (arglen >= 5) ? 
userDictionary(arguments[4]) : null; + this._analyzer = null; return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); @@ -80,11 +95,8 @@ public final class KuromojiUDF extends GenericUDF { @Override public List<Text> evaluate(DeferredObject[] arguments) throws HiveException { - JapaneseAnalyzer analyzer = _analyzer; - if (analyzer == null) { - CharArraySet stopwords = stopWords(_stopWordsArray); - analyzer = new JapaneseAnalyzer(null, _mode, stopwords, _stoptags); - this._analyzer = analyzer; + if (_analyzer == null) { + this._analyzer = new JapaneseAnalyzer(_userDict, _mode, _stopWords, _stopTags); } Object arg0 = arguments[0].get(); @@ -96,12 +108,12 @@ public final class KuromojiUDF extends GenericUDF { final List<Text> results = new ArrayList<Text>(32); TokenStream stream = null; try { - stream = analyzer.tokenStream("", line); + stream = _analyzer.tokenStream("", line); if (stream != null) { analyzeTokens(stream, results); } } catch (IOException e) { - IOUtils.closeQuietly(analyzer); + IOUtils.closeQuietly(_analyzer); throw new HiveException(e); } finally { IOUtils.closeQuietly(stream); @@ -115,7 +127,8 @@ public final class KuromojiUDF extends GenericUDF { } @Nonnull - private static Mode tokenizationMode(@Nonnull ObjectInspector oi) throws UDFArgumentException { + private static Mode tokenizationMode(@Nonnull final ObjectInspector oi) + throws UDFArgumentException { final String arg = HiveUtils.getConstString(oi); if (arg == null) { return Mode.NORMAL; @@ -131,14 +144,18 @@ public final class KuromojiUDF extends GenericUDF { mode = JapaneseTokenizer.DEFAULT_MODE; } else { throw new UDFArgumentException( - "Expected NORMAL|SEARCH|EXTENDED|DEFAULT but got an unexpected mode: " + arg); + "Expected NORMAL|SEARCH|EXTENDED|DEFAULT but got an unexpected mode: " + arg); } return mode; } @Nonnull - private static CharArraySet stopWords(@Nonnull final String[] array) + private static CharArraySet 
stopWords(@Nonnull final ObjectInspector oi) throws UDFArgumentException { + if (HiveUtils.isVoidOI(oi)) { + return JapaneseAnalyzer.getDefaultStopSet(); + } + final String[] array = HiveUtils.getConstStringArray(oi); if (array == null) { return JapaneseAnalyzer.getDefaultStopSet(); } @@ -152,6 +169,9 @@ public final class KuromojiUDF extends GenericUDF { @Nonnull private static Set<String> stopTags(@Nonnull final ObjectInspector oi) throws UDFArgumentException { + if (HiveUtils.isVoidOI(oi)) { + return JapaneseAnalyzer.getDefaultStopTags(); + } final String[] array = HiveUtils.getConstStringArray(oi); if (array == null) { return JapaneseAnalyzer.getDefaultStopTags(); @@ -170,6 +190,89 @@ public final class KuromojiUDF extends GenericUDF { return results; } + @Nullable + private static UserDictionary userDictionary(@Nonnull final ObjectInspector oi) + throws UDFArgumentException { + if (HiveUtils.isConstListOI(oi)) { + return userDictionary(HiveUtils.getConstStringArray(oi)); + } else if (HiveUtils.isConstString(oi)) { + return userDictionary(HiveUtils.getConstString(oi)); + } else { + throw new UDFArgumentException( + "User dictionary MUST be given as an array of constant string or constant string (URL)"); + } + } + + @Nullable + private static UserDictionary userDictionary(@Nullable final String[] userDictArray) + throws UDFArgumentException { + if (userDictArray == null) { + return null; + } + + final StringBuilder builder = new StringBuilder(); + for (String row : userDictArray) { + builder.append(row).append('\n'); + } + final Reader reader = new StringReader(builder.toString()); + try { + return UserDictionary.open(reader); // return null if empty + } catch (Throwable e) { + throw new UDFArgumentException( + "Failed to create user dictionary based on the given array<string>: " + e); + } + } + + @Nullable + private static UserDictionary userDictionary(@Nullable final String userDictURL) + throws UDFArgumentException { + if (userDictURL == null) { + return 
null; + } + + final HttpURLConnection conn; + try { + conn = HttpUtils.getHttpURLConnection(userDictURL); + } catch (IllegalArgumentException e) { + throw new UDFArgumentException("Failed to create HTTP connection to the URL: " + e); + } catch (IOException e) { + throw new UDFArgumentException("Failed to create HTTP connection to the URL: " + e); + } + + // allow to read as a compressed GZIP file for efficiency + conn.setRequestProperty("Accept-Encoding", "gzip"); + + conn.setConnectTimeout(CONNECT_TIMEOUT_MS); // throw exception from connect() + conn.setReadTimeout(READ_TIMEOUT_MS); // throw exception from getXXX() methods + + final int responseCode; + try { + responseCode = conn.getResponseCode(); + } catch (IOException e) { + throw new UDFArgumentException("Failed to get response code: " + e); + } + if (responseCode != 200) { + throw new UDFArgumentException("Got invalid response code: " + responseCode); + } + + final InputStream is; + try { + is = IOUtils.decodeInputStream(HttpUtils.getLimitedInputStream(conn, + MAX_INPUT_STREAM_SIZE)); + } catch (NullPointerException e) { + throw new UDFArgumentException("Failed to get input stream from the connection: " + e); + } catch (IOException e) { + throw new UDFArgumentException("Failed to get input stream from the connection: " + e); + } + + final Reader reader = new InputStreamReader(is); + try { + return UserDictionary.open(reader); // return null if empty + } catch (Throwable e) { + throw new UDFArgumentException("Failed to parse the file in CSV format: " + e); + } + } + private static void analyzeTokens(@Nonnull TokenStream stream, @Nonnull List<Text> results) throws IOException { // instantiate an attribute placeholder once http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/07a7d51b/nlp/src/test/java/hivemall/nlp/tokenizer/KuromojiUDFTest.java ---------------------------------------------------------------------- diff --git a/nlp/src/test/java/hivemall/nlp/tokenizer/KuromojiUDFTest.java 
b/nlp/src/test/java/hivemall/nlp/tokenizer/KuromojiUDFTest.java index acd54c5..d0c5e86 100644 --- a/nlp/src/test/java/hivemall/nlp/tokenizer/KuromojiUDFTest.java +++ b/nlp/src/test/java/hivemall/nlp/tokenizer/KuromojiUDFTest.java @@ -1,24 +1,25 @@ /* - * Hivemall: Hive scalable Machine Learning Library + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * Copyright (C) 2015 Makoto YUI - * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) + * http://www.apache.org/licenses/LICENSE-2.0 * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/ package hivemall.nlp.tokenizer; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; @@ -40,7 +41,7 @@ import com.esotericsoftware.kryo.io.Output; public class KuromojiUDFTest { @Test - public void testOneArgment() throws UDFArgumentException, IOException { + public void testOneArgument() throws UDFArgumentException, IOException { GenericUDF udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[1]; // line @@ -50,14 +51,14 @@ public class KuromojiUDFTest { } @Test - public void testTwoArgment() throws UDFArgumentException, IOException { + public void testTwoArgument() throws UDFArgumentException, IOException { GenericUDF udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[2]; // line argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; // mode argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, null); + PrimitiveCategory.STRING, null); udf.initialize(argOIs); udf.close(); } @@ -69,7 +70,7 @@ public class KuromojiUDFTest { argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; // mode argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, new Text("normal")); + PrimitiveCategory.STRING, new Text("normal")); udf.initialize(argOIs); udf.close(); } @@ -82,48 +83,92 @@ public class KuromojiUDFTest { argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; // mode argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, new Text("unsupported mode")); + PrimitiveCategory.STRING, new Text("unsupported mode")); udf.initialize(argOIs); udf.close(); } @Test - public void testThreeArgment() throws UDFArgumentException, IOException { + public void testThreeArgument() throws UDFArgumentException, IOException { GenericUDF 
udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[3]; // line argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; // mode argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, null); + PrimitiveCategory.STRING, null); // stopWords argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); udf.initialize(argOIs); udf.close(); } @Test - public void testFourArgment() throws UDFArgumentException, IOException { + public void testFourArgument() throws UDFArgumentException, IOException { GenericUDF udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[4]; // line argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; // mode argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( - PrimitiveCategory.STRING, null); + PrimitiveCategory.STRING, null); // stopWords argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); // stopTags argOIs[3] = ObjectInspectorFactory.getStandardConstantListObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); udf.initialize(argOIs); udf.close(); } @Test - public void testEvalauteOneRow() throws IOException, HiveException { + public void testFiveArgumentArray() throws UDFArgumentException, IOException { + GenericUDF udf = new KuromojiUDF(); + ObjectInspector[] argOIs = new ObjectInspector[5]; + // line + argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + // mode + argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + 
PrimitiveCategory.STRING, null); + // stopWords + argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + // stopTags + argOIs[3] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + // userDictUrl + argOIs[4] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + udf.initialize(argOIs); + udf.close(); + } + + @Test + public void testFiveArgumenString() throws UDFArgumentException, IOException { + GenericUDF udf = new KuromojiUDF(); + ObjectInspector[] argOIs = new ObjectInspector[5]; + // line + argOIs[0] = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + // mode + argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, null); + // stopWords + argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + // stopTags + argOIs[3] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, null); + // userDictUrl + argOIs[4] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, null); + udf.initialize(argOIs); + udf.close(); + } + + @Test + public void testEvaluateOneRow() throws IOException, HiveException { KuromojiUDF udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[1]; // line @@ -143,7 +188,7 @@ public class KuromojiUDFTest { } @Test - public void testEvalauteTwoRows() throws IOException, HiveException { + public void testEvaluateTwoRows() throws IOException, HiveException { KuromojiUDF udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[1]; // line @@ -173,6 +218,115 @@ public class KuromojiUDFTest { } 
@Test + public void testEvaluateUserDictArray() throws IOException, HiveException { + KuromojiUDF udf = new KuromojiUDF(); + ObjectInspector[] argOIs = new ObjectInspector[5]; + // line + argOIs[0] = PrimitiveObjectInspectorFactory.writableStringObjectInspector; + // mode + argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, null); + // stopWords + argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, null); + // stopTags + argOIs[3] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, null); + // userDictArray (from https://raw.githubusercontent.com/atilika/kuromoji/909fd6b32bf4e9dc86b7599de5c9b50ca8f004a1/kuromoji-core/src/test/resources/userdict.txt) + List<String> userDict = new ArrayList<String>(); + userDict.add("日本経済新聞,日本 経済 新聞,ニホン ケイザイ シンブン,カスタム名詞"); + userDict.add("関西国際空港,関西 国際 空港,カンサイ コクサイ クウコウ,テスト名詞"); + argOIs[4] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, userDict); + udf.initialize(argOIs); + + DeferredObject[] args = new DeferredObject[1]; + args[0] = new DeferredObject() { + public Text get() throws HiveException { + return new Text("日本経済新聞。"); + } + }; + + List<Text> tokens = udf.evaluate(args); + + Assert.assertNotNull(tokens); + Assert.assertEquals(3, tokens.size()); + + udf.close(); + } + + @Test(expected = UDFArgumentException.class) + public void testEvaluateInvalidUserDictURL() throws IOException, HiveException { + KuromojiUDF udf = new KuromojiUDF(); + ObjectInspector[] argOIs = new ObjectInspector[5]; + // line + argOIs[0] = PrimitiveObjectInspectorFactory.writableStringObjectInspector; + // mode + argOIs[1] =
PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, null); + // stopWords + argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, null); + // stopTags + argOIs[3] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, null); + // userDictUrl + argOIs[4] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, new Text("http://google.com/")); + udf.initialize(argOIs); + + DeferredObject[] args = new DeferredObject[1]; + args[0] = new DeferredObject() { + public Text get() throws HiveException { + return new Text("クロモジのJapaneseAnalyzerを使ってみる。テスト。"); + } + }; + + List<Text> tokens = udf.evaluate(args); + Assert.assertNotNull(tokens); + + udf.close(); + } + + @Test + public void testEvaluateUserDictURL() throws IOException, HiveException { + KuromojiUDF udf = new KuromojiUDF(); + ObjectInspector[] argOIs = new ObjectInspector[5]; + // line + argOIs[0] = PrimitiveObjectInspectorFactory.writableStringObjectInspector; + // mode + argOIs[1] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, null); + // stopWords + argOIs[2] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, null); + // stopTags + argOIs[3] = ObjectInspectorFactory.getStandardConstantListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector, null); + // userDictUrl (Kuromoji official sample user defined dict on GitHub) + // e.g., "日本経済新聞" will be "日本", "経済", and "新聞" + argOIs[4] = PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + PrimitiveCategory.STRING, + new Text( +
"https://raw.githubusercontent.com/atilika/kuromoji/909fd6b32bf4e9dc86b7599de5c9b50ca8f004a1/kuromoji-core/src/test/resources/userdict.txt")); + udf.initialize(argOIs); + + DeferredObject[] args = new DeferredObject[1]; + args[0] = new DeferredObject() { + public Text get() throws HiveException { + return new Text("クロモジのJapaneseAnalyzerを使ってみる。日本経済新聞。"); + } + }; + + List<Text> tokens = udf.evaluate(args); + + Assert.assertNotNull(tokens); + Assert.assertEquals(7, tokens.size()); + + udf.close(); + } + + @Test public void testSerializeByKryo() throws UDFArgumentException { final KuromojiUDF udf = new KuromojiUDF(); ObjectInspector[] argOIs = new ObjectInspector[1];