Repository: incubator-hivemall
Updated Branches:
  refs/heads/master e158f58ac -> bcae1534a


Close #45: [HIVEMALL-71] Handle null values and add unit tests to RescaleUDF


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/bcae1534
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/bcae1534
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/bcae1534

Branch: refs/heads/master
Commit: bcae1534a2894d98746210ed20985c115a99678e
Parents: e158f58
Author: Yuming Wang <wgy...@gmail.com>
Authored: Thu Feb 16 15:16:45 2017 +0900
Committer: myui <yuin...@gmail.com>
Committed: Thu Feb 16 15:16:45 2017 +0900

----------------------------------------------------------------------
 .gitignore                                      |  1 +
 .../java/hivemall/ftvec/scaling/RescaleUDF.java | 85 ++++++++++++-----
 .../hivemall/ftvec/scaling/RescaleUDFTest.java  | 99 ++++++++++++++++++++
 3 files changed, 161 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bcae1534/.gitignore
----------------------------------------------------------------------
diff --git a/.gitignore b/.gitignore
index 55d6a7d..3b44c62 100644
--- a/.gitignore
+++ b/.gitignore
@@ -22,3 +22,4 @@ spark/bin/zinc-*
 .classpath
 .project
 metastore_db
+.java-version

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bcae1534/core/src/main/java/hivemall/ftvec/scaling/RescaleUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/ftvec/scaling/RescaleUDF.java b/core/src/main/java/hivemall/ftvec/scaling/RescaleUDF.java
index 21a30d5..a3e4799 100644
--- a/core/src/main/java/hivemall/ftvec/scaling/RescaleUDF.java
+++ b/core/src/main/java/hivemall/ftvec/scaling/RescaleUDF.java
@@ -20,8 +20,12 @@ package hivemall.ftvec.scaling;
 
 import static hivemall.utils.hadoop.WritableUtils.val;
 
+import javax.annotation.CheckForNull;
+import javax.annotation.Nullable;
+
 import org.apache.hadoop.hive.ql.exec.Description;
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.udf.UDFType;
 import org.apache.hadoop.io.FloatWritable;
 import org.apache.hadoop.io.Text;
@@ -36,49 +40,82 @@ import org.apache.hadoop.io.Text;
 @UDFType(deterministic = true, stateful = false)
 public final class RescaleUDF extends UDF {
 
-    public FloatWritable evaluate(final float value, final float min, final float max) {
-        return val(min_max_normalization(value, min, max));
+    @Nullable
+    public FloatWritable evaluate(@Nullable final Double value, @CheckForNull final Double min,
+            @CheckForNull final Double max) throws HiveException {
+        return evaluate(double2Float(value), double2Float(min), double2Float(max));
     }
 
-    public FloatWritable evaluate(final double value, final double min, final double max) {
+    @Nullable
+    public FloatWritable evaluate(@Nullable final Float value, @CheckForNull final Float min,
+            @CheckForNull final Float max) throws HiveException {
+        if (value == null) {
+            return null;
+        }
+
+        if (min == null)
+            throw new HiveException("min should not be null");
+        if (max == null)
+            throw new HiveException("max should not be null");
+
         return val(min_max_normalization(value, min, max));
     }
 
-    public Text evaluate(final String s, final double min, final double max) {
-        String[] fv = s.split(":");
-        if (fv.length != 2) {
-            throw new IllegalArgumentException("Invalid feature value representation: " + s);
-        }
-        double v = Float.parseFloat(fv[1]);
-        float scaled_v = min_max_normalization(v, min, max);
-        String ret = fv[0] + ':' + Float.toString(scaled_v);
-        return val(ret);
+    @Nullable
+    public Text evaluate(@Nullable final String s, @CheckForNull final Double min,
+            @CheckForNull final Double max) throws HiveException {
+        return evaluate(s, double2Float(min), double2Float(max));
     }
 
-    public Text evaluate(final String s, final float min, final float max) {
-        String[] fv = s.split(":");
+    @Nullable
+    public Text evaluate(@Nullable final String s, @CheckForNull final Float min,
+            @CheckForNull final Float max) throws HiveException {
+        if (s == null) {
+            return null;
+        }
+
+        if (min == null)
+            throw new HiveException("min should not be null");
+        if (max == null)
+            throw new HiveException("max should not be null");
+
+        final String[] fv = s.split(":");
         if (fv.length != 2) {
-            throw new IllegalArgumentException("Invalid feature value representation: " + s);
+            throw new HiveException(String.format("Invalid feature value " + "representation: %s",
+                s));
+        }
+        float v;
+        try {
+            v = Float.parseFloat(fv[1]);
+        } catch (NumberFormatException e) {
+            throw new HiveException(String.format("Invalid feature value "
+                    + "representation: %s, %s can't parse to float.", s, 
fv[1]));
         }
-        float v = Float.parseFloat(fv[1]);
-        float scaled_v = min_max_normalization(v, min, max);
-        String ret = fv[0] + ':' + Float.toString(scaled_v);
+
+        float scaled_v = min_max_normalization(v, min.floatValue(), max.floatValue());
+        String ret = fv[0] + ':' + scaled_v;
         return val(ret);
     }
 
-    private static float min_max_normalization(final float value, final float min, final float max) {
+    private static float min_max_normalization(final float value, final float min, final float max)
+            throws HiveException {
+        if (min > max) {
+            throw new HiveException("min value `" + min + "` SHOULD be less 
than max value `" + max
+                    + '`');
+        }
         if (min == max) {
             return 0.5f;
         }
         return (value - min) / (max - min);
     }
 
-    private static float min_max_normalization(final double value, final double min,
-            final double max) {
-        if (min == max) {
-            return 0.5f;
+    @Nullable
+    private static Float double2Float(@Nullable final Double value) {
+        if (value == null) {
+            return null;
+        } else {
+            return Float.valueOf(value.floatValue());
         }
-        return (float) ((value - min) / (max - min));
     }
 
 }
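
A minimal usage sketch (not part of the commit) of the updated numeric overloads. The class name RescaleSketch and the sample values are illustrative assumptions; the behavior mirrors the diff above: the result is (value - min) / (max - min), min == max falls back to 0.5, a null value yields null, and a null min or max raises HiveException.

import hivemall.ftvec.scaling.RescaleUDF;

import org.apache.hadoop.hive.ql.metadata.HiveException;

public class RescaleSketch {
    public static void main(String[] args) throws HiveException {
        RescaleUDF udf = new RescaleUDF();
        // (2 - 1) / (5 - 1) = 0.25
        System.out.println(udf.evaluate(2.0d, 1.0d, 5.0d));
        // min == max falls back to 0.5
        System.out.println(udf.evaluate(3.0d, 3.0d, 3.0d));
        // a null value passes through as null
        System.out.println(udf.evaluate((Double) null, 0.0d, 1.0d));
        // a null min (or max) raises HiveException
        udf.evaluate(1.0d, null, 1.0d);
    }
}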

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/bcae1534/core/src/test/java/hivemall/ftvec/scaling/RescaleUDFTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/ftvec/scaling/RescaleUDFTest.java b/core/src/test/java/hivemall/ftvec/scaling/RescaleUDFTest.java
new file mode 100644
index 0000000..887511f
--- /dev/null
+++ b/core/src/test/java/hivemall/ftvec/scaling/RescaleUDFTest.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.ftvec.scaling;
+
+import static org.junit.Assert.assertEquals;
+import hivemall.utils.hadoop.WritableUtils;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.junit.Before;
+import org.junit.Test;
+
+public class RescaleUDFTest {
+
+    RescaleUDF udf = null;
+
+    @Before
+    public void init() {
+        udf = new RescaleUDF();
+    }
+
+    @Test
+    public void test() throws Exception {
+        assertEquals(WritableUtils.val(0.5f), udf.evaluate(1f, 1f, 1f));
+        assertEquals(WritableUtils.val(0.5f), udf.evaluate(0.1d, 0.1d, 0.1d));
+        assertEquals(WritableUtils.val("1:0.5"), udf.evaluate("1:1", 1f, 1f));
+        assertEquals(WritableUtils.val("1:0.5"), udf.evaluate("1:1", 0.1d, 
0.1d));
+    }
+
+    @Test(expected = HiveException.class)
+    public void testFloatMinIsNull() throws Exception {
+        udf.evaluate(1f, null, 1f);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testFloatMaxIsNull() throws Exception {
+        udf.evaluate(1f, 1f, null);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testDoubleMinIsNull() throws Exception {
+        udf.evaluate(0.1, null, 0.1);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testDoubleMaxIsNull() throws Exception {
+        udf.evaluate(0.1, 0.1, null);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testBothNull() throws Exception {
+        udf.evaluate(0.1, null, null);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testIllegalArgumentException1() throws Exception {
+        udf.evaluate("1:", 0.1d, 0.1d);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testStringMinIsNull() throws Exception {
+        udf.evaluate("1:1", null, 1d);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testStringMaxIsNull() throws Exception {
+        udf.evaluate("1:1", 1d, null);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testCannotParseNumber() throws Exception {
+        udf.evaluate("1:string", 0.1d, 0.1d);
+    }
+
+    @Test
+    public void testMinMaxEquals() throws Exception {
+        udf.evaluate(0.1d, 0.1d, 0.1d);
+    }
+
+    @Test(expected = HiveException.class)
+    public void testInvalidMinMax() throws Exception {
+        udf.evaluate(0.1d, 0.2d, 0.1d);
+    }
+
+}
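
A similar hedged sketch for the "feature:value" String overload exercised by the new tests; the feature name "price" and the numbers are made up for illustration. The value part of the pair is rescaled, a null input string returns null, and a value that cannot be parsed now raises HiveException instead of IllegalArgumentException.

import hivemall.ftvec.scaling.RescaleUDF;

import org.apache.hadoop.hive.ql.metadata.HiveException;

public class RescaleStringSketch {
    public static void main(String[] args) throws HiveException {
        RescaleUDF udf = new RescaleUDF();
        // (250 - 100) / (500 - 100) = 0.375, so this prints "price:0.375"
        System.out.println(udf.evaluate("price:250.0", 100.0d, 500.0d));
        // a null input string returns null
        System.out.println(udf.evaluate((String) null, 0.0d, 1.0d));
        // a value that cannot be parsed as a float raises HiveException
        udf.evaluate("price:oops", 0.0d, 1.0d);
    }
}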
