http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/docs/gitbook/binaryclass/titanic_rf.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md index 29784e0..2b54074 100644 --- a/docs/gitbook/binaryclass/titanic_rf.md +++ b/docs/gitbook/binaryclass/titanic_rf.md @@ -175,7 +175,7 @@ from # Prediction ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; set hive.auto.convert.join=true; SET hive.mapjoin.optimized.hashtable=false; SET mapred.reduce.tasks=16; @@ -202,7 +202,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT @@ -319,7 +320,7 @@ from > [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392] > 0.1838351822503962 ```sql -SET hivevar:classification=true; +-- SET hivevar:classification=true; SET hive.mapjoin.optimized.hashtable=false; SET mapred.reduce.tasks=16; @@ -345,7 +346,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- 
tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/docs/gitbook/multiclass/iris_randomforest.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md index b421297..bfc197f 100644 --- a/docs/gitbook/multiclass/iris_randomforest.md +++ b/docs/gitbook/multiclass/iris_randomforest.md @@ -206,7 +206,7 @@ from # Prediction ```sql -set hivevar:classification=true; +-- set hivevar:classification=true; set hive.auto.convert.join=true; set hive.mapjoin.optimized.hashtable=false; @@ -225,7 +225,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- to use the old model in v0.5-rc.1 or later FROM model p @@ -265,7 +266,8 @@ FROM ( -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted -- hivemall v0.5-rc.1 or later p.model_weight, - tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted + tree_predict(p.model_id, p.model, t.features, "-classification") as predicted + -- tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted -- tree_predict_v1(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted as predicted -- to use the old model in v0.5-rc.1 or later FROM ( SELECT http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/mixserv/pom.xml ---------------------------------------------------------------------- diff --git a/mixserv/pom.xml b/mixserv/pom.xml 
index 0a1b387..ff27b09 100644 --- a/mixserv/pom.xml +++ b/mixserv/pom.xml @@ -16,14 +16,13 @@ specific language governing permissions and limitations under the License. --> -<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <parent> <groupId>org.apache.hivemall</groupId> <artifactId>hivemall</artifactId> - <version>0.5.0-incubating-SNAPSHOT</version> + <version>0.5.1-incubating-SNAPSHOT</version> <relativePath>../pom.xml</relativePath> </parent> @@ -40,49 +39,26 @@ <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-common</artifactId> - <version>${hadoop.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-mapreduce-client-core</artifactId> - <version>${hadoop.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.hive</groupId> <artifactId>hive-exec</artifactId> - <version>${hive.version}</version> <scope>provided</scope> - <exclusions> - <exclusion> - <artifactId>jetty</artifactId> - <groupId>org.mortbay.jetty</groupId> - </exclusion> - <exclusion> - <groupId>javax.jdo</groupId> - <artifactId>jdo2-api</artifactId> - </exclusion> - <exclusion> - <groupId>asm-parent</groupId> - <artifactId>asm-parent</artifactId> - </exclusion> - <exclusion> - <groupId>asm</groupId> - <artifactId>asm</artifactId> - </exclusion> - </exclusions> </dependency> <dependency> <groupId>javax.jdo</groupId> <artifactId>jdo2-api</artifactId> - <version>2.3-eb</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.google.guava</groupId> 
<artifactId>guava</artifactId> - <version>${guava.version}</version> <scope>provided</scope> </dependency> @@ -103,19 +79,16 @@ <dependency> <groupId>commons-cli</groupId> <artifactId>commons-cli</artifactId> - <version>1.2</version> <scope>compile</scope> </dependency> <dependency> <groupId>commons-logging</groupId> <artifactId>commons-logging</artifactId> - <version>1.0.4</version> <scope>compile</scope> </dependency> <dependency> <groupId>log4j</groupId> <artifactId>log4j</artifactId> - <version>1.2.17</version> <scope>compile</scope> </dependency> <dependency> @@ -130,28 +103,21 @@ <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> - <version>${junit.version}</version> <scope>test</scope> </dependency> <dependency> <groupId>org.mockito</groupId> <artifactId>mockito-all</artifactId> - <version>1.10.19</version> <scope>test</scope> </dependency> </dependencies> <build> - <directory>target</directory> - <outputDirectory>target/classes</outputDirectory> - <finalName>${project.artifactId}-${project.version}</finalName> - <testOutputDirectory>target/test-classes</testOutputDirectory> <plugins> <!-- hivemall-mixserv-xx-fat.jar including all dependencies --> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> - <version>3.1.0</version> <executions> <execution> <id>jar-with-dependencies</id> @@ -170,7 +136,7 @@ <include>commons-cli:commons-cli</include> <include>commons-logging:commons-logging</include> <include>log4j:log4j</include> - <include>io.netty:netty-all</include> + <include>io.netty:netty-all</include> </includes> </artifactSet> <!-- maven-shade-plugin cannot handle the dependency of log4j because @@ -198,8 +164,7 @@ </filter> </filters> <transformers> - <transformer - implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> + <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> <manifestEntries>
<Main-Class>hivemall.mix.server.MixServer</Main-Class> <Implementation-Title>${project.name}</Implementation-Title> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/nlp/pom.xml ---------------------------------------------------------------------- diff --git a/nlp/pom.xml b/nlp/pom.xml index dc77c06..782e41d 100644 --- a/nlp/pom.xml +++ b/nlp/pom.xml @@ -16,14 +16,13 @@ specific language governing permissions and limitations under the License. --> -<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <parent> <groupId>org.apache.hivemall</groupId> <artifactId>hivemall</artifactId> - <version>0.5.0-incubating-SNAPSHOT</version> + <version>0.5.1-incubating-SNAPSHOT</version> <relativePath>../pom.xml</relativePath> </parent> @@ -40,77 +39,51 @@ <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-common</artifactId> - <version>${hadoop.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-mapreduce-client-core</artifactId> - <version>${hadoop.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>org.apache.hive</groupId> <artifactId>hive-exec</artifactId> - <version>${hive.version}</version> <scope>provided</scope> - <exclusions> - <exclusion> - <artifactId>jetty</artifactId> - <groupId>org.mortbay.jetty</groupId> - </exclusion> - <exclusion> - <groupId>javax.jdo</groupId> - <artifactId>jdo2-api</artifactId> - </exclusion> - <exclusion> - <groupId>asm-parent</groupId> - <artifactId>asm-parent</artifactId> - </exclusion> - <exclusion> - 
<groupId>asm</groupId> - <artifactId>asm</artifactId> - </exclusion> - </exclusions> </dependency> <dependency> <groupId>commons-cli</groupId> <artifactId>commons-cli</artifactId> - <version>1.2</version> <scope>provided</scope> </dependency> <dependency> <groupId>commons-logging</groupId> <artifactId>commons-logging</artifactId> - <version>1.0.4</version> <scope>provided</scope> </dependency> <dependency> <groupId>log4j</groupId> <artifactId>log4j</artifactId> - <version>1.2.17</version> <scope>provided</scope> </dependency> <dependency> <groupId>javax.jdo</groupId> <artifactId>jdo2-api</artifactId> - <version>2.3-eb</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> - <version>${guava.version}</version> <scope>provided</scope> </dependency> + + <!-- compile scope --> <dependency> <groupId>org.apache.hivemall</groupId> <artifactId>hivemall-core</artifactId> <version>${project.version}</version> - <scope>provided</scope> + <scope>compile</scope> </dependency> - - <!-- compile scope --> <dependency> <groupId>org.apache.lucene</groupId> <artifactId>lucene-analyzers-kuromoji</artifactId> @@ -128,7 +101,6 @@ <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> - <version>${junit.version}</version> <scope>test</scope> </dependency> <dependency> @@ -140,98 +112,4 @@ </dependencies> - <build> - <directory>target</directory> - <outputDirectory>target/classes</outputDirectory> - <finalName>${project.artifactId}-${project.version}</finalName> - <testOutputDirectory>target/test-classes</testOutputDirectory> - <plugins> - <!-- hivemall-nlp-xx.jar --> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <version>2.5</version> - <configuration> - <finalName>${project.artifactId}-${project.version}</finalName> - <outputDirectory>${project.parent.build.directory}</outputDirectory> - </configuration> - </plugin> - <!-- 
hivemall-nlp-xx-with-dependencies.jar including minimum dependencies --> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-shade-plugin</artifactId> - <version>3.1.0</version> - <executions> - <execution> - <id>jar-with-dependencies</id> - <phase>package</phase> - <goals> - <goal>shade</goal> - </goals> - <configuration> - <finalName>${project.artifactId}-${project.version}-with-dependencies</finalName> - <outputDirectory>${project.parent.build.directory}</outputDirectory> - <minimizeJar>true</minimizeJar> - <createDependencyReducedPom>false</createDependencyReducedPom> - <artifactSet> - <includes> - <include>org.apache.hivemall:hivemall-core</include> - <include>org.apache.lucene:lucene-analyzers-kuromoji</include> - <include>org.apache.lucene:lucene-analyzers-smartcn</include> - <include>org.apache.lucene:lucene-analyzers-common</include> - <include>org.apache.lucene:lucene-core</include> - </includes> - </artifactSet> - <filters> - <filter> - <artifact>*:*</artifact> - <excludes> - <exclude>META-INF/LICENSE.txt</exclude> - </excludes> - </filter> - <filter> - <artifact>org.apache.lucene:lucene-analyzers-kuromoji</artifact> - <includes> - <include>**</include> - </includes> - </filter> - <filter> - <artifact>org.apache.lucene:lucene-analyzers-smartcn</artifact> - <includes> - <include>**</include> - </includes> - </filter> - <filter> - <artifact>org.apache.lucene:lucene-analyzers-common</artifact> - <includes> - <include>**</include> - </includes> - </filter> - <filter> - <artifact>org.apache.lucene:lucene-core</artifact> - <includes> - <include>**</include> - </includes> - </filter> - </filters> - <transformers> - <transformer - implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> - <manifestEntries> - <Implementation-Title>${project.name}</Implementation-Title> - <Implementation-Version>${project.version}</Implementation-Version> - 
<Implementation-Vendor>${project.organization.name}</Implementation-Vendor> - </manifestEntries> - </transformer> - <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheNoticeResourceTransformer"> - <addHeader>false</addHeader> - </transformer> - </transformers> - </configuration> - </execution> - </executions> - </plugin> - </plugins> - </build> - </project> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java ---------------------------------------------------------------------- diff --git a/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java b/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java index 93fd18c..411c89e 100644 --- a/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java +++ b/nlp/src/main/java/hivemall/nlp/tokenizer/KuromojiUDF.java @@ -19,15 +19,19 @@ package hivemall.nlp.tokenizer; import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.io.IOUtils; import hivemall.utils.io.HttpUtils; +import hivemall.utils.io.IOUtils; +import hivemall.utils.lang.ExceptionUtils; +import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.io.IOException; import java.io.Reader; import java.io.StringReader; import java.net.HttpURLConnection; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -55,8 +59,7 @@ import org.apache.lucene.analysis.ja.dict.UserDictionary; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.util.CharArraySet; -@Description( - name = "tokenize_ja", +@Description(name = "tokenize_ja", value = "_FUNC_(String line [, const string mode = \"normal\", const array<string> stopWords, const array<string> stopTags, const array<string> userDict (or string userDictURL)])" + " - 
returns tokenized strings in array<string>") @UDFType(deterministic = true, stateful = false) @@ -77,20 +80,21 @@ public final class KuromojiUDF extends GenericUDF { public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { final int arglen = arguments.length; if (arglen < 1 || arglen > 5) { - throw new UDFArgumentException("Invalid number of arguments for `tokenize_ja`: " - + arglen); + throw new UDFArgumentException( + "Invalid number of arguments for `tokenize_ja`: " + arglen); } this._mode = (arglen >= 2) ? tokenizationMode(arguments[1]) : Mode.NORMAL; - this._stopWords = (arglen >= 3) ? stopWords(arguments[2]) - : JapaneseAnalyzer.getDefaultStopSet(); - this._stopTags = (arglen >= 4) ? stopTags(arguments[3]) - : JapaneseAnalyzer.getDefaultStopTags(); + this._stopWords = + (arglen >= 3) ? stopWords(arguments[2]) : JapaneseAnalyzer.getDefaultStopSet(); + this._stopTags = + (arglen >= 4) ? stopTags(arguments[3]) : JapaneseAnalyzer.getDefaultStopTags(); this._userDict = (arglen >= 5) ? 
userDictionary(arguments[4]) : null; this._analyzer = null; - return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + return ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.writableStringObjectInspector); } @Override @@ -219,7 +223,8 @@ public final class KuromojiUDF extends GenericUDF { return UserDictionary.open(reader); // return null if empty } catch (Throwable e) { throw new UDFArgumentException( - "Failed to create user dictionary based on the given array<string>: " + e); + "Failed to create user dictionary based on the given array<string>: " + + builder.toString() + '\n' + ExceptionUtils.prettyPrintStackTrace(e)); } } @@ -234,7 +239,8 @@ public final class KuromojiUDF extends GenericUDF { try { conn = HttpUtils.getHttpURLConnection(userDictURL); } catch (IllegalArgumentException | IOException e) { - throw new UDFArgumentException("Failed to create HTTP connection to the URL: " + e); + throw new UDFArgumentException("Failed to create HTTP connection to the URL: " + + userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e)); } // allow to read as a compressed GZIP file for efficiency @@ -247,7 +253,8 @@ public final class KuromojiUDF extends GenericUDF { try { responseCode = conn.getResponseCode(); } catch (IOException e) { - throw new UDFArgumentException("Failed to get response code: " + e); + throw new UDFArgumentException("Failed to get response code: " + userDictURL + '\n' + + ExceptionUtils.prettyPrintStackTrace(e)); } if (responseCode != 200) { throw new UDFArgumentException("Got invalid response code: " + responseCode); @@ -255,17 +262,24 @@ public final class KuromojiUDF extends GenericUDF { final InputStream is; try { - is = IOUtils.decodeInputStream(HttpUtils.getLimitedInputStream(conn, - MAX_INPUT_STREAM_SIZE)); + is = IOUtils.decodeInputStream( + HttpUtils.getLimitedInputStream(conn, MAX_INPUT_STREAM_SIZE)); } catch 
(NullPointerException | IOException e) { - throw new UDFArgumentException("Failed to get input stream from the connection: " + e); + throw new UDFArgumentException("Failed to get input stream from the connection: " + + userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e)); } - final Reader reader = new InputStreamReader(is); + CharsetDecoder decoder = + StandardCharsets.UTF_8.newDecoder() + .onMalformedInput(CodingErrorAction.REPORT) + .onUnmappableCharacter(CodingErrorAction.REPORT); + final Reader reader = new InputStreamReader(is, decoder); try { return UserDictionary.open(reader); // return null if empty } catch (Throwable e) { - throw new UDFArgumentException("Failed to parse the file in CSV format: " + e); + throw new UDFArgumentException( + "Failed to parse the file in CSV format (UTF-8 encoding is expected): " + + userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e)); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index e9c19dd..e594006 100644 --- a/pom.xml +++ b/pom.xml @@ -16,13 +16,12 @@ specific language governing permissions and limitations under the License. 
--> -<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>org.apache.hivemall</groupId> <artifactId>hivemall</artifactId> - <version>0.5.0-incubating-SNAPSHOT</version> + <version>0.5.1-incubating-SNAPSHOT</version> <parent> <groupId>org.apache</groupId> @@ -51,7 +50,8 @@ <url>https://git-wip-us.apache.org/repos/asf/incubator-hivemall.git</url> <connection>scm:git:https://git-wip-us.apache.org/repos/asf/incubator-hivemall.git</connection> <developerConnection>scm:git:https://git-wip-us.apache.org/repos/asf/incubator-hivemall.git</developerConnection> - </scm> + <tag>v0.5.0-rc1</tag> + </scm> <mailingLists> <mailingList> @@ -152,8 +152,8 @@ <name>Tsuyoshi Ozawa</name> <email>ozawa[at]apache.org</email> <url>https://people.apache.org/~ozawa/</url> - <organization></organization> - <organizationUrl></organizationUrl> + <organization /> + <organizationUrl /> <roles> <role>PPMC Member</role> </roles> @@ -249,15 +249,14 @@ <module>nlp</module> <module>xgboost</module> <module>mixserv</module> + <module>spark</module> + <module>dist</module> </modules> <properties> - <java.source.version>1.7</java.source.version> - <java.target.version>1.7</java.target.version> + <main.basedir>${project.basedir}</main.basedir> <maven.compiler.source>1.7</maven.compiler.source> <maven.compiler.target>1.7</maven.compiler.target> - <scala.version>2.11.8</scala.version> - <scala.binary.version>2.11</scala.binary.version> <maven.build.timestamp.format>yyyy</maven.build.timestamp.format> <build.year>${maven.build.timestamp}</build.year> 
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> @@ -269,9 +268,9 @@ <guava.version>11.0.2</guava.version> <junit.version>4.12</junit.version> <dependency.locations.enabled>false</dependency.locations.enabled> - <main.basedir>${project.basedir}</main.basedir> - <maven-enforcer-plugin.version>3.0.0-M1</maven-enforcer-plugin.version> + <maven-enforcer.requireMavenVersion>[3.3.1,)</maven-enforcer.requireMavenVersion> <surefire.version>2.19.1</surefire.version> + <xgboost.version>0.7-rc2</xgboost.version> </properties> <distributionManagement> @@ -315,113 +314,6 @@ <profiles> <profile> - <id>spark-2.2</id> - <modules> - <module>spark/spark-2.2</module> - <module>spark/spark-common</module> - </modules> - <properties> - <spark.version>2.2.0</spark.version> - <spark.binary.version>2.2</spark.binary.version> - </properties> - <build> - <plugins> - <!-- Spark-2.2 only supports Java 8 --> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-enforcer-plugin</artifactId> - <version>${maven-enforcer-plugin.version}</version> - <executions> - <execution> - <id>enforce-versions</id> - <phase>validate</phase> - <goals> - <goal>enforce</goal> - </goals> - <configuration> - <rules> - <requireProperty> - <property>java.source.version</property> - <regex>1.8</regex> - <regexMessage>When -Pspark-2.2 set, java.source.version must be 1.8</regexMessage> - </requireProperty> - <requireProperty> - <property>java.target.version</property> - <regex>1.8</regex> - <regexMessage>When -Pspark-2.2 set, java.target.version must be 1.8</regexMessage> - </requireProperty> - </rules> - </configuration> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - <profile> - <id>spark-2.1</id> - <modules> - <module>spark/spark-2.1</module> - <module>spark/spark-common</module> - </modules> - <properties> - <spark.version>2.1.1</spark.version> - <spark.binary.version>2.1</spark.binary.version> - </properties> - </profile> - <profile> - 
<id>spark-2.0</id> - <modules> - <module>spark/spark-2.0</module> - <module>spark/spark-common</module> - </modules> - <properties> - <spark.version>2.0.2</spark.version> - <spark.binary.version>2.0</spark.binary.version> - </properties> - </profile> - <profile> - <id>java7</id> - <properties> - <spark.test.jvm.opts>-ea -Xms768m -Xmx1024m -XX:PermSize=128m -XX:MaxPermSize=512m -XX:ReservedCodeCacheSize=512m</spark.test.jvm.opts> - </properties> - <activation> - <jdk>[,1.8)</jdk> <!-- version < 1.8 --> - </activation> - </profile> - <profile> - <id>java8</id> - <properties> - <spark.test.jvm.opts>-ea -Xms768m -Xmx1024m -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=512m -XX:ReservedCodeCacheSize=512m</spark.test.jvm.opts> - </properties> - <activation> - <jdk>[1.8,)</jdk> <!-- version >= 1.8 --> - </activation> - </profile> - <profile> - <id>compile-xgboost</id> - <build> - <plugins> - <plugin> - <artifactId>exec-maven-plugin</artifactId> - <groupId>org.codehaus.mojo</groupId> - <executions> - <execution> - <id>native</id> - <phase>generate-sources</phase> - <goals> - <goal>exec</goal> - </goals> - <configuration> - <executable>./bin/build_xgboost.sh</executable> - </configuration> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - <profile> <id>doclint-java8-disable</id> <activation> <jdk>[1.8,)</jdk> @@ -432,6 +324,110 @@ </profile> </profiles> + <dependencyManagement> + <dependencies> + <!-- provided scope --> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-common</artifactId> + <version>${hadoop.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-core</artifactId> + <version>${hadoop.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hive</groupId> + <artifactId>hive-exec</artifactId> + <version>${hive.version}</version> + <scope>provided</scope> + <exclusions> + 
<exclusion> + <artifactId>jetty</artifactId> + <groupId>org.mortbay.jetty</groupId> + </exclusion> + <exclusion> + <groupId>javax.jdo</groupId> + <artifactId>jdo2-api</artifactId> + </exclusion> + <exclusion> + <groupId>asm-parent</groupId> + <artifactId>asm-parent</artifactId> + </exclusion> + <exclusion> + <groupId>asm</groupId> + <artifactId>asm</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>commons-cli</groupId> + <artifactId>commons-cli</artifactId> + <version>1.2</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>commons-logging</groupId> + <artifactId>commons-logging</artifactId> + <version>1.0.4</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + <version>1.2.17</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>javax.jdo</groupId> + <artifactId>jdo2-api</artifactId> + <version>2.3-eb</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <version>${guava.version}</version> + <scope>provided</scope> + </dependency> + + <!-- test scope --> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <version>${junit.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-all</artifactId> + <version>1.10.19</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <version>1.10.19</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.powermock</groupId> + <artifactId>powermock-module-junit4</artifactId> + <version>1.6.3</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.powermock</groupId> + <artifactId>powermock-api-mockito</artifactId> + <version>1.6.3</version> + <scope>test</scope> + 
</dependency> + </dependencies> + </dependencyManagement> + <build> <directory>target</directory> <outputDirectory>target/classes</outputDirectory> @@ -441,6 +437,25 @@ <pluginManagement> <plugins> <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <version>3.0.2</version> + <configuration> + <finalName>${project.artifactId}-${project.version}</finalName> + <outputDirectory>${main.basedir}/target</outputDirectory> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <version>3.1.0</version> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-enforcer-plugin</artifactId> + <version>3.0.0-M1</version> + </plugin> + <plugin> <!-- mvn formatter:format --> <groupId>net.revelc.code</groupId> <artifactId>formatter-maven-plugin</artifactId> @@ -475,6 +490,11 @@ <useDefaultExcludes>false</useDefaultExcludes> <excludes> <exclude>docs/gitbook/node_modules/**</exclude> + <exclude>target/</exclude> + <exclude>src/main/java/hivemall/utils/codec/Base91.java</exclude> + <exclude>src/main/java/hivemall/utils/math/FastMath.java</exclude> + <exclude>src/main/java/hivemall/smile/classification/DecisionTree.java</exclude> + <exclude>src/main/java/hivemall/smile/regression/RegressionTree.java</exclude> </excludes> <encoding>UTF-8</encoding> <headerDefinitions> @@ -575,14 +595,42 @@ <artifactId>maven-enforcer-plugin</artifactId> <executions> <execution> - <id>enforce-maven</id> + <id>enforce-JAVA_HOME-is-set</id> + <goals> + <goal>enforce</goal> + </goals> + <configuration> + <rules> + <requireEnvironmentVariable> + <variableName>JAVA_HOME</variableName> + </requireEnvironmentVariable> + </rules> + <fail>true</fail> + </configuration> + </execution> + <execution> + <id>enforce-JAVA8_HOME-is-set</id> + <goals> + <goal>enforce</goal> + </goals> + <configuration> + <rules> + <requireEnvironmentVariable> + 
<variableName>JAVA8_HOME</variableName> + </requireEnvironmentVariable> + </rules> + <fail>true</fail> + </configuration> + </execution> + <execution> + <id>required-maven-version</id> <goals> <goal>enforce</goal> </goals> <configuration> <rules> <requireMavenVersion> - <version>[3.3.1,)</version> + <version>${maven-enforcer.requireMavenVersion}</version> </requireMavenVersion> </rules> </configuration> @@ -610,8 +658,8 @@ <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <configuration> - <source>${java.source.version}</source> - <target>${java.target.version}</target> + <source>${maven.compiler.source}</source> + <target>${maven.compiler.target}</target> <debug>true</debug> <debuglevel>lines,vars,source</debuglevel> <encoding>UTF-8</encoding> @@ -688,30 +736,6 @@ </dependencies> </plugin> <!-- end mvn site --> - <plugin> - <groupId>org.scalastyle</groupId> - <artifactId>scalastyle-maven-plugin</artifactId> - <version>0.8.0</version> - <configuration> - <verbose>false</verbose> - <failOnViolation>true</failOnViolation> - <includeTestSourceDirectory>true</includeTestSourceDirectory> - <failOnWarning>false</failOnWarning> - <sourceDirectory>${basedir}/src/main/scala</sourceDirectory> - <testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory> - <configLocation>spark/spark-common/scalastyle-config.xml</configLocation> - <outputFile>${basedir}/target/scalastyle-output.xml</outputFile> - <inputEncoding>${project.build.sourceEncoding}</inputEncoding> - <outputEncoding>${project.reporting.outputEncoding}</outputEncoding> - </configuration> - <executions> - <execution> - <goals> - <goal>check</goal> - </goals> - </execution> - </executions> - </plugin> <!-- mvn apache-rat:check --> <plugin> <groupId>org.apache.rat</groupId> http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/pom.xml ---------------------------------------------------------------------- diff --git a/spark/common/pom.xml 
b/spark/common/pom.xml
new file mode 100644
index 0000000..a6262e8
--- /dev/null
+++ b/spark/common/pom.xml
@@ -0,0 +1,64 @@
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements. See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership. The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License. You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied. See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>org.apache.hivemall</groupId>
+        <artifactId>hivemall-spark</artifactId>
+        <version>0.5.1-incubating-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+
+    <artifactId>hivemall-spark-common</artifactId>
+    <name>Hivemall on Spark Common</name>
+    <packaging>jar</packaging>
+
+    <properties>
+        <main.basedir>${project.parent.parent.basedir}</main.basedir> <!-- base directory of the grandparent project, i.e. two levels above this module -->
+    </properties>
+
+    <dependencies>
+        <!-- provided scope: needed to compile but not bundled. No versions are declared here; presumably they are managed by a parent POM (TODO confirm) -->
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-common</artifactId>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hadoop</groupId>
+            <artifactId>hadoop-mapreduce-client-core</artifactId>
+            <scope>provided</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.hive</groupId>
+            <artifactId>hive-exec</artifactId>
+            <scope>provided</scope>
+        </dependency>
+
+        <!-- compile scope: required at compile time and passed on to dependents of this module -->
+        <dependency>
+            <groupId>org.apache.hivemall</groupId>
+            <artifactId>hivemall-core</artifactId>
+            <scope>compile</scope>
+        </dependency>
+    </dependencies>
+
+</project>
+
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java b/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java
new file mode 100644
index 0000000..cf10ed7
--- /dev/null
+++ b/spark/common/src/main/java/hivemall/dataset/LogisticRegressionDataGeneratorUDTFWrapper.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package hivemall.dataset; + +import hivemall.UDTFWithOptions; + +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Random; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.Collector; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; + +/** + * A wrapper of [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]. This wrapper is needed + * because Spark cannot handle HadoopUtils#getTaskId() correctly. + */ +@Description(name = "lr_datagen", + value = "_FUNC_(options string) - Generates a logistic regression dataset") +public final class LogisticRegressionDataGeneratorUDTFWrapper extends UDTFWithOptions { + private transient LogisticRegressionDataGeneratorUDTF udtf = + new LogisticRegressionDataGeneratorUDTF(); + + @Override + protected Options getOptions() { + Options options = null; + try { + Method m = udtf.getClass().getDeclaredMethod("getOptions"); + m.setAccessible(true); + options = (Options) m.invoke(udtf); + } catch (Exception e) { + e.printStackTrace(); + } + return options; + } + + @SuppressWarnings("all") + @Override + protected CommandLine processOptions(ObjectInspector[] objectInspectors) + throws UDFArgumentException { + CommandLine commands = null; + try { + Method m = udtf.getClass().getDeclaredMethod("processOptions"); + m.setAccessible(true); + commands = (CommandLine) m.invoke(udtf, objectInspectors); + } catch (Exception e) { + e.printStackTrace(); + } + return commands; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + try { + // 
Extract a collector for LogisticRegressionDataGeneratorUDTF + Field collector = GenericUDTF.class.getDeclaredField("collector"); + collector.setAccessible(true); + udtf.setCollector((Collector) collector.get(this)); + + // To avoid HadoopUtils#getTaskId() + Class<?> clazz = udtf.getClass(); + Field rnd1 = clazz.getDeclaredField("rnd1"); + Field rnd2 = clazz.getDeclaredField("rnd2"); + Field r_seed = clazz.getDeclaredField("r_seed"); + r_seed.setAccessible(true); + final long seed = r_seed.getLong(udtf) + (int) Thread.currentThread().getId(); + rnd1.setAccessible(true); + rnd2.setAccessible(true); + rnd1.set(udtf, new Random(seed)); + rnd2.set(udtf, new Random(seed + 1)); + } catch (Exception e) { + e.printStackTrace(); + } + return udtf.initialize(argOIs); + } + + @Override + public void process(Object[] objects) throws HiveException { + udtf.process(objects); + } + + @Override + public void close() throws HiveException { + udtf.close(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java new file mode 100644 index 0000000..b454fd9 --- /dev/null +++ b/spark/common/src/main/java/hivemall/ftvec/AddBiasUDFWrapper.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec; + +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; + +/** + * A wrapper of [[hivemall.ftvec.AddBiasUDF]]. + * + * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<> + * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector. 
+ */ +@Description(name = "add_bias", + value = "_FUNC_(features in array<string>) - Returns features with a bias as array<string>") +@UDFType(deterministic = true, stateful = false) +public class AddBiasUDFWrapper extends GenericUDF { + private AddBiasUDF udf = new AddBiasUDF(); + private ListObjectInspector argumentOI = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "add_bias() has an single arguments: array<string> features"); + } + + switch (arguments[0].getCategory()) { + case LIST: + argumentOI = (ListObjectInspector) arguments[0]; + ObjectInspector elmOI = argumentOI.getListElementObjectInspector(); + if (elmOI.getCategory().equals(Category.PRIMITIVE)) { + if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.STRING) { + break; + } + } + default: + throw new UDFArgumentTypeException(0, "Type mismatch: features"); + } + + return ObjectInspectorFactory.getStandardListObjectInspector(argumentOI.getListElementObjectInspector()); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 1); + @SuppressWarnings("unchecked") + final List<String> input = (List<String>) argumentOI.getList(arguments[0].get()); + return udf.evaluate(input); + } + + @Override + public String getDisplayString(String[] children) { + return "add_bias(" + Arrays.toString(children) + ")"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java new file mode 100644 index 0000000..0b687db --- /dev/null +++ 
b/spark/common/src/main/java/hivemall/ftvec/AddFeatureIndexUDFWrapper.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec; + +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +/** + * A wrapper of [[hivemall.ftvec.AddFeatureIndexUDF]]. + * + * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<> + * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector. 
+ */ +@Description( + name = "add_feature_index", + value = "_FUNC_(dense features in array<double>) - Returns a feature vector with feature indices") +@UDFType(deterministic = true, stateful = false) +public class AddFeatureIndexUDFWrapper extends GenericUDF { + private AddFeatureIndexUDF udf = new AddFeatureIndexUDF(); + private ListObjectInspector argumentOI = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "add_feature_index() has an single arguments: array<double> features"); + } + + switch (arguments[0].getCategory()) { + case LIST: + argumentOI = (ListObjectInspector) arguments[0]; + ObjectInspector elmOI = argumentOI.getListElementObjectInspector(); + if (elmOI.getCategory().equals(Category.PRIMITIVE)) { + if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.DOUBLE) { + break; + } + } + default: + throw new UDFArgumentTypeException(0, "Type mismatch: features"); + } + + return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 1); + @SuppressWarnings("unchecked") + final List<Double> input = (List<Double>) argumentOI.getList(arguments[0].get()); + return udf.evaluate(input); + } + + @Override + public String getDisplayString(String[] children) { + return "add_feature_index(" + Arrays.toString(children) + ")"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java new file mode 100644 
index 0000000..5924468 --- /dev/null +++ b/spark/common/src/main/java/hivemall/ftvec/ExtractFeatureUDFWrapper.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec; + +import java.util.Arrays; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +/** + * A wrapper of [[hivemall.ftvec.ExtractFeatureUDF]]. + * + * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<> + * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector. 
+ */ +@Description(name = "extract_feature", + value = "_FUNC_(feature in string) - Returns a parsed feature as string") +@UDFType(deterministic = true, stateful = false) +public class ExtractFeatureUDFWrapper extends GenericUDF { + private ExtractFeatureUDF udf = new ExtractFeatureUDF(); + private PrimitiveObjectInspector argumentOI = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "extract_feature() has an single arguments: string feature"); + } + + argumentOI = (PrimitiveObjectInspector) arguments[0]; + if (argumentOI.getPrimitiveCategory() != PrimitiveCategory.STRING) { + throw new UDFArgumentTypeException(0, "Type mismatch: feature"); + } + + return PrimitiveObjectInspectorFactory.javaStringObjectInspector; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 1); + final String input = (String) argumentOI.getPrimitiveJavaObject(arguments[0].get()); + return udf.evaluate(input); + } + + @Override + public String getDisplayString(String[] children) { + return "extract_feature(" + Arrays.toString(children) + ")"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java new file mode 100644 index 0000000..8580247 --- /dev/null +++ b/spark/common/src/main/java/hivemall/ftvec/ExtractWeightUDFWrapper.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec; + +import java.util.Arrays; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +/** + * A wrapper of [[hivemall.ftvec.ExtractWeightUDF]]. + * + * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle List<> + * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector. 
+ */ +@Description(name = "extract_weight", + value = "_FUNC_(feature in string) - Returns the weight of a feature as string") +@UDFType(deterministic = true, stateful = false) +public class ExtractWeightUDFWrapper extends GenericUDF { + private ExtractWeightUDF udf = new ExtractWeightUDF(); + private PrimitiveObjectInspector argumentOI = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "extract_weight() has an single arguments: string feature"); + } + + argumentOI = (PrimitiveObjectInspector) arguments[0]; + if (argumentOI.getPrimitiveCategory() != PrimitiveCategory.STRING) { + throw new UDFArgumentTypeException(0, "Type mismatch: feature"); + } + + return PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.DOUBLE); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 1); + final String input = (String) argumentOI.getPrimitiveJavaObject(arguments[0].get()); + return udf.evaluate(input); + } + + @Override + public String getDisplayString(String[] children) { + return "extract_weight(" + Arrays.toString(children) + ")"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java new file mode 100644 index 0000000..584be6c --- /dev/null +++ b/spark/common/src/main/java/hivemall/ftvec/SortByFeatureUDFWrapper.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.ftvec; + +import java.util.Arrays; +import java.util.Map; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; + +/** + * A wrapper of [[hivemall.ftvec.SortByFeatureUDF]]. + * + * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark cannot handle Map<> + * as a return type in Hive UDF. Therefore, the type must be passed via ObjectInspector. 
+ */ +@Description(name = "sort_by_feature", + value = "_FUNC_(map in map<int,float>) - Returns a sorted map") +@UDFType(deterministic = true, stateful = false) +public class SortByFeatureUDFWrapper extends GenericUDF { + private SortByFeatureUDF udf = new SortByFeatureUDF(); + private MapObjectInspector argumentOI = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException( + "sorted_by_feature() has an single arguments: map<int, float> map"); + } + + switch (arguments[0].getCategory()) { + case MAP: + argumentOI = (MapObjectInspector) arguments[0]; + ObjectInspector keyOI = argumentOI.getMapKeyObjectInspector(); + ObjectInspector valueOI = argumentOI.getMapValueObjectInspector(); + if (keyOI.getCategory().equals(Category.PRIMITIVE) + && valueOI.getCategory().equals(Category.PRIMITIVE)) { + final PrimitiveCategory keyCategory = ((PrimitiveObjectInspector) keyOI).getPrimitiveCategory(); + final PrimitiveCategory valueCategory = ((PrimitiveObjectInspector) valueOI).getPrimitiveCategory(); + if (keyCategory == PrimitiveCategory.INT + && valueCategory == PrimitiveCategory.FLOAT) { + break; + } + } + default: + throw new UDFArgumentTypeException(0, "Type mismatch: map"); + } + + + return ObjectInspectorFactory.getStandardMapObjectInspector( + argumentOI.getMapKeyObjectInspector(), argumentOI.getMapValueObjectInspector()); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 1); + @SuppressWarnings("unchecked") + final Map<IntWritable, FloatWritable> input = (Map<IntWritable, FloatWritable>) argumentOI.getMap(arguments[0].get()); + return udf.evaluate(input); + } + + @Override + public String getDisplayString(String[] children) { + return "sort_by_feature(" + Arrays.toString(children) + ")"; + } +} 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java b/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java new file mode 100644 index 0000000..db533be --- /dev/null +++ b/spark/common/src/main/java/hivemall/ftvec/scaling/L2NormalizationUDFWrapper.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.ftvec.scaling; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters.Converter; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.Text; + +/** + * A wrapper of [[hivemall.ftvec.scaling.L2NormalizationUDF]]. + * + * NOTE: This is needed to avoid the issue of Spark reflection. That is, spark-1.3 cannot handle + * List<> as a return type in Hive UDF. The type must be passed via ObjectInspector. This issues has + * been reported in SPARK-6747, so a future release of Spark makes the wrapper obsolete. 
+ */ +public class L2NormalizationUDFWrapper extends GenericUDF { + private L2NormalizationUDF udf = new L2NormalizationUDF(); + + private transient List<Text> retValue = new ArrayList<Text>(); + private transient Converter toListText = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentLengthException("normalize() has an only single argument."); + } + + switch (arguments[0].getCategory()) { + case LIST: + ObjectInspector elmOI = ((ListObjectInspector) arguments[0]).getListElementObjectInspector(); + if (elmOI.getCategory().equals(Category.PRIMITIVE)) { + if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.STRING) { + break; + } + } + default: + throw new UDFArgumentTypeException(0, + "normalize() must have List[String] as an argument, but " + + arguments[0].getTypeName() + " was found."); + } + + // Create a ObjectInspector converter for arguments + ObjectInspector outputElemOI = ObjectInspectorFactory.getReflectionObjectInspector( + Text.class, ObjectInspectorOptions.JAVA); + ObjectInspector outputOI = ObjectInspectorFactory.getStandardListObjectInspector(outputElemOI); + toListText = ObjectInspectorConverters.getConverter(arguments[0], outputOI); + + ObjectInspector listElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + ObjectInspector returnElemOI = ObjectInspectorUtils.getStandardObjectInspector(listElemOI); + return ObjectInspectorFactory.getStandardListObjectInspector(returnElemOI); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 1); + @SuppressWarnings("unchecked") + final List<Text> input = (List<Text>) toListText.convert(arguments[0].get()); + retValue = udf.evaluate(input); + return retValue; + } + + @Override + public String getDisplayString(String[] children) { + return "normalize(" + Arrays.toString(children) + 
")"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java b/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java new file mode 100644 index 0000000..d3bcbe6 --- /dev/null +++ b/spark/common/src/main/java/hivemall/knn/lsh/MinHashesUDFWrapper.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.knn.lsh; + +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; + +/** A wrapper of [[hivemall.knn.lsh.MinHashesUDF]]. */ +@Description( + name = "minhashes", + value = "_FUNC_(features in array<string>, noWeight in boolean) - Returns hashed features as array<int>") +@UDFType(deterministic = true, stateful = false) +public class MinHashesUDFWrapper extends GenericUDF { + private MinHashesUDF udf = new MinHashesUDF(); + private ListObjectInspector featuresOI = null; + private PrimitiveObjectInspector noWeightOI = null; + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 2) { + throw new UDFArgumentLengthException( + "minhashes() has 2 arguments: array<string> features, boolean noWeight"); + } + + // Check argument types + switch (arguments[0].getCategory()) { + case LIST: + featuresOI = (ListObjectInspector) arguments[0]; + ObjectInspector elmOI = featuresOI.getListElementObjectInspector(); + if (elmOI.getCategory().equals(Category.PRIMITIVE)) { + if (((PrimitiveObjectInspector) elmOI).getPrimitiveCategory() == PrimitiveCategory.STRING) { + 
break; + } + } + default: + throw new UDFArgumentTypeException(0, "Type mismatch: features"); + } + + noWeightOI = (PrimitiveObjectInspector) arguments[1]; + if (noWeightOI.getPrimitiveCategory() != PrimitiveCategory.BOOLEAN) { + throw new UDFArgumentException("Type mismatch: noWeight"); + } + + return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.getPrimitiveWritableObjectInspector(PrimitiveCategory.INT)); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 2); + @SuppressWarnings("unchecked") + final List<String> features = (List<String>) featuresOI.getList(arguments[0].get()); + final Boolean noWeight = PrimitiveObjectInspectorUtils.getBoolean(arguments[1].get(), + noWeightOI); + return udf.evaluate(features, noWeight); + } + + @Override + public String getDisplayString(String[] children) { + /** + * TODO: Need to return hive-specific type names. + */ + return "minhashes(" + Arrays.toString(children) + ")"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java b/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java new file mode 100644 index 0000000..f386223 --- /dev/null +++ b/spark/common/src/main/java/hivemall/tools/mapred/RowIdUDFWrapper.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.mapred; + +import java.util.UUID; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +/** An alternative implementation of [[hivemall.tools.mapred.RowIdUDF]]. */ +@Description( + name = "rowid", + value = "_FUNC_() - Returns a generated row id of a form {TASK_ID}-{UUID}-{SEQUENCE_NUMBER}") +@UDFType(deterministic = false, stateful = true) +public class RowIdUDFWrapper extends GenericUDF { + // RowIdUDF is directly used because spark cannot + // handle HadoopUtils#getTaskId(). 
+ + private long sequence; + private long taskId; + + public RowIdUDFWrapper() { + this.sequence = 0L; + this.taskId = Thread.currentThread().getId(); + } + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 0) { + throw new UDFArgumentLengthException("row_number() has no argument."); + } + + return PrimitiveObjectInspectorFactory.javaStringObjectInspector; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + assert (arguments.length == 0); + sequence++; + /** + * TODO: Check if it is unique over all tasks in executors of Spark. + */ + return taskId + "-" + UUID.randomUUID() + "-" + sequence; + } + + @Override + public String getDisplayString(String[] children) { + return "row_number()"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3a718713/spark/common/src/main/scala/hivemall/HivemallException.scala ---------------------------------------------------------------------- diff --git a/spark/common/src/main/scala/hivemall/HivemallException.scala b/spark/common/src/main/scala/hivemall/HivemallException.scala new file mode 100644 index 0000000..53f6756 --- /dev/null +++ b/spark/common/src/main/scala/hivemall/HivemallException.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/**
 * Exception type used by the Hivemall Spark integration.
 *
 * @param message detail message forwarded to [[java.lang.Exception]]
 * @param cause underlying cause forwarded to [[java.lang.Exception]]
 */
class HivemallException(message: String, cause: Throwable) extends Exception(message, cause) {

  /** Creates an exception that carries only a message; the cause is passed as `null`. */
  def this(message: String) = {
    this(message, null)
  }
}
// Used for DataFrame#explode
case class HivemallFeature(feature: String)

/**
 * Class that represents the features and labels of a data point for Hivemall.
 *
 * @param label Label for this data point.
 * @param features List of features for this data point.
 */
case class HivemallLabeledPoint(label: Float = 0.0f, features: Seq[String]) {

  /** Renders the point as `label,[f1,f2,...]`, the textual form read by `HivemallLabeledPoint.parse`. */
  override def toString: String = {
    "%s,%s".format(label, features.mkString("[", ",", "]"))
  }
}

object HivemallLabeledPoint {

  /**
   * Simple parser for HivemallLabeledPoint.
   *
   * The text before the first comma is the label and the rest is the bracketed
   * feature list, e.g. `1.0,[a,b]`. When the input has no comma, or starts with one,
   * a dummy point with label `0.0` and no features is returned.
   *
   * @param s text to parse
   * @return the parsed data point
   * @throws HivemallException if the feature list has a misplaced ',' or lacks a closing ']'
   * @throws NumberFormatException if the label text is not a valid float
   */
  def parse(s: String): HivemallLabeledPoint = {
    val (label, features) = s.indexOf(',') match {
      case d if d > 0 => (s.substring(0, d), s.substring(d + 1))
      case _ => ("0.0", "[]") // Dummy
    }
    HivemallLabeledPoint(label.toFloat, parseTuple(new StringTokenizer(features, "[],", true)))
  }

  // TODO: Support to parse rows without labels
  /**
   * Consumes tokens up to and including the closing ']' of the current tuple.
   *
   * The tokenizer must return its delimiters ('[', ']' and ',') as tokens.
   *
   * @return the collected items as an immutable list
   */
  private[this] def parseTuple(tokenizer: StringTokenizer): Seq[String] = {
    val items = ListBuffer.empty[String]
    var parsing = true
    // True only right after an item or a nested tuple, i.e. where a ',' is legal.
    var allowDelim = false
    while (parsing && tokenizer.hasMoreTokens()) {
      tokenizer.nextToken() match {
        case "[" =>
          // The nested call reads through the matching ']', which also ends this level.
          items ++= parseTuple(tokenizer)
          parsing = false
          allowDelim = true
        case "," =>
          if (allowDelim) {
            allowDelim = false
          } else {
            throw new HivemallException("Found ',' at a wrong position.")
          }
        case "]" =>
          parsing = false
        case token =>
          items.append(token)
          allowDelim = true
      }
    }
    if (parsing) {
      // Tokens ran out before a closing bracket was seen.
      throw new HivemallException("A tuple must end with ']'.")
    }
    // Return an immutable copy so the internal buffer is not exposed to callers.
    items.toList
  }
}