http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java b/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java new file mode 100644 index 0000000..194085c --- /dev/null +++ b/core/src/main/java/hivemall/tools/sanity/RaiseErrorUDF.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.tools.sanity; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDF; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; + +@Description(name = "raise_error", value = "_FUNC_() or _FUNC_(string msg) - Throws an error") +@UDFType(deterministic = true, stateful = false) +public final class RaiseErrorUDF extends UDF { + + public boolean evaluate() throws HiveException { + throw new HiveException(); + } + + public boolean evaluate(String errorMessage) throws HiveException { + throw new HiveException(errorMessage); + } + +}
/*
 * Origin: core/src/main/java/hivemall/tools/vector/VectorAddUDF.java
 * (Apache Hivemall, commit 1e1b77ea) — Licensed to the ASF under the
 * Apache License, Version 2.0; see http://www.apache.org/licenses/LICENSE-2.0
 */
package hivemall.tools.vector;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;

import java.util.Arrays;
import java.util.List;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * Element-wise addition of two numeric arrays.
 *
 * <p>When both element types are integral the result is {@code array<bigint>};
 * otherwise the result is {@code array<double>}. A {@code null} element in
 * either operand yields a {@code null} element in the result.</p>
 */
@Description(name = "vector_add",
        value = "_FUNC_(array<NUMBER> x, array<NUMBER> y) - Perform vector ADD operation.")
@UDFType(deterministic = true, stateful = false)
public final class VectorAddUDF extends GenericUDF {

    private ListObjectInspector leftOI, rightOI;
    private PrimitiveObjectInspector leftElemOI, rightElemOI;
    // true -> produce array<double>; false -> produce array<bigint>
    private boolean useDouble;

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
        }

        this.leftOI = HiveUtils.asListOI(argOIs[0]);
        this.rightOI = HiveUtils.asListOI(argOIs[1]);
        this.leftElemOI = HiveUtils.asNumberOI(leftOI.getListElementObjectInspector());
        this.rightElemOI = HiveUtils.asNumberOI(rightOI.getListElementObjectInspector());

        // integer + integer stays integral; any floating-point operand promotes to double
        final boolean bothIntegral =
                HiveUtils.isIntegerOI(leftElemOI) && HiveUtils.isIntegerOI(rightElemOI);
        this.useDouble = !bothIntegral;
        final PrimitiveObjectInspector resultElemOI = bothIntegral
                ? PrimitiveObjectInspectorFactory.javaLongObjectInspector
                : PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
        return ObjectInspectorFactory.getStandardListObjectInspector(resultElemOI);
    }

    /**
     * @return the element-wise sum, or {@code null} when either argument is SQL NULL
     * @throws HiveException when the two vectors have different lengths
     */
    @Nullable
    @Override
    public List<?> evaluate(@Nonnull DeferredObject[] args) throws HiveException {
        final Object left = args[0].get();
        final Object right = args[1].get();
        if (left == null || right == null) {
            return null; // NULL in -> NULL out
        }

        final int leftLen = leftOI.getListLength(left);
        final int rightLen = rightOI.getListLength(right);
        if (leftLen != rightLen) {
            throw new HiveException(
                "vector lengths do not match. x=" + leftOI.getList(left) + ", y=" + rightOI.getList(right));
        }

        return useDouble ? addAsDouble(left, right, leftLen) : addAsLong(left, right, leftLen);
    }

    /** Sums the two vectors as doubles; a null operand leaves the result element null. */
    @Nonnull
    private List<Double> addAsDouble(@Nonnull final Object left, @Nonnull final Object right,
            @Nonnegative final int size) {
        final Double[] result = new Double[size];
        for (int i = 0; i < size; i++) {
            final Object li = leftOI.getListElement(left, i);
            final Object ri = rightOI.getListElement(right, i);
            if (li != null && ri != null) {
                final double sum = PrimitiveObjectInspectorUtils.getDouble(li, leftElemOI)
                        + PrimitiveObjectInspectorUtils.getDouble(ri, rightElemOI);
                result[i] = Double.valueOf(sum);
            }
        }
        return Arrays.asList(result);
    }

    /** Sums the two vectors as longs; a null operand leaves the result element null. */
    @Nonnull
    private List<Long> addAsLong(@Nonnull final Object left, @Nonnull final Object right,
            @Nonnegative final int size) {
        final Long[] result = new Long[size];
        for (int i = 0; i < size; i++) {
            final Object li = leftOI.getListElement(left, i);
            final Object ri = rightOI.getListElement(right, i);
            if (li != null && ri != null) {
                final long sum = PrimitiveObjectInspectorUtils.getLong(li, leftElemOI)
                        + PrimitiveObjectInspectorUtils.getLong(ri, rightElemOI);
                result[i] = Long.valueOf(sum);
            }
        }
        return Arrays.asList(result);
    }

    @Override
    public String getDisplayString(String[] args) {
        return "vector_add(" + StringUtils.join(args, ',') + ")";
    }

}
/*
 * Origin: core/src/main/java/hivemall/tools/vector/VectorDotUDF.java
 * (Apache Hivemall, commit 1e1b77ea) — Licensed to the ASF under the
 * Apache License, Version 2.0; see http://www.apache.org/licenses/LICENSE-2.0
 */
package hivemall.tools.vector;

import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.lang.StringUtils;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;

/**
 * Element-wise multiplication of a numeric array with either another numeric
 * array (same length) or a numeric scalar. Despite the "dot" name, the result
 * is the per-element product vector ({@code array<double>}), not a scalar sum.
 *
 * <p>Fixes over the previous revision: the error message for a bad second
 * argument said "send argument" instead of "second argument", and the
 * {@code evalutor} field was misspelled.</p>
 */
@Description(name = "vector_dot",
        value = "_FUNC_(array<NUMBER> x, array<NUMBER> y) - Performs vector dot product.",
        extended = "_FUNC_(array<NUMBER> x, NUMBER y) - Performs vector multiplication")
@UDFType(deterministic = true, stateful = false)
public final class VectorDotUDF extends GenericUDF {

    // strategy chosen in initialize(): array*array or array*scalar
    private Evaluator evaluator;

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length != 2) {
            throw new UDFArgumentLengthException("Expected 2 arguments, but got " + argOIs.length);
        }

        ObjectInspector argOI0 = argOIs[0];
        if (!HiveUtils.isNumberListOI(argOI0)) {
            throw new UDFArgumentException(
                "Expected array<number> for the first argument: " + argOI0.getTypeName());
        }
        ListObjectInspector xListOI = HiveUtils.asListOI(argOI0);

        ObjectInspector argOI1 = argOIs[1];
        if (HiveUtils.isNumberListOI(argOI1)) {
            this.evaluator = new Dot2DVectors(xListOI, HiveUtils.asListOI(argOI1));
        } else if (HiveUtils.isNumberOI(argOI1)) {
            this.evaluator = new Multiply2D1D(xListOI, argOI1);
        } else {
            // "second" was previously misspelled as "send"
            throw new UDFArgumentException(
                "Expected array<number> or number for the second argument: " + argOI1.getTypeName());
        }

        return ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
    }

    /**
     * @return the per-element product vector, or {@code null} when either argument is SQL NULL
     * @throws HiveException when two array arguments have different lengths
     */
    @Nullable
    @Override
    public List<Double> evaluate(DeferredObject[] args) throws HiveException {
        final Object arg0 = args[0].get();
        final Object arg1 = args[1].get();
        if (arg0 == null || arg1 == null) {
            return null;
        }

        return evaluator.dot(arg0, arg1);
    }

    /** Strategy for the two supported right-hand-side shapes (vector or scalar). */
    interface Evaluator extends Serializable {

        @Nonnull
        List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException;

    }

    /** Multiplies each element of a vector by a scalar. */
    static final class Multiply2D1D implements Evaluator {
        private static final long serialVersionUID = -9090211833041797311L;

        private final ListObjectInspector xListOI;
        private final PrimitiveObjectInspector xElemOI;
        private final PrimitiveObjectInspector yOI;

        Multiply2D1D(@Nonnull ListObjectInspector xListOI, @Nonnull ObjectInspector yOI)
                throws UDFArgumentTypeException {
            this.xListOI = xListOI;
            this.xElemOI = HiveUtils.asNumberOI(xListOI.getListElementObjectInspector());
            this.yOI = HiveUtils.asNumberOI(yOI);
        }

        @Override
        public List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException {
            final double yd = PrimitiveObjectInspectorUtils.getDouble(y, yOI);

            final int xLen = xListOI.getListLength(x);
            final Double[] arr = new Double[xLen];
            for (int i = 0; i < xLen; i++) {
                Object xi = xListOI.getListElement(x, i);
                if (xi == null) {
                    continue; // null element -> null result element
                }
                double xd = PrimitiveObjectInspectorUtils.getDouble(xi, xElemOI);
                double v = xd * yd;
                arr[i] = Double.valueOf(v);
            }

            return Arrays.asList(arr);
        }

    }

    /** Element-wise product of two equal-length vectors. */
    static final class Dot2DVectors implements Evaluator {
        private static final long serialVersionUID = -8783159823009951347L;

        private final ListObjectInspector xListOI, yListOI;
        private final PrimitiveObjectInspector xElemOI, yElemOI;

        Dot2DVectors(@Nonnull ListObjectInspector xListOI, @Nonnull ListObjectInspector yListOI)
                throws UDFArgumentTypeException {
            this.xListOI = xListOI;
            this.yListOI = yListOI;
            this.xElemOI = HiveUtils.asNumberOI(xListOI.getListElementObjectInspector());
            this.yElemOI = HiveUtils.asNumberOI(yListOI.getListElementObjectInspector());
        }

        @Override
        public List<Double> dot(@Nonnull Object x, @Nonnull Object y) throws HiveException {
            final int xLen = xListOI.getListLength(x);
            final int yLen = yListOI.getListLength(y);
            if (xLen != yLen) {
                throw new HiveException("vector lengths do not match. x=" + xListOI.getList(x)
                        + ", y=" + yListOI.getList(y));
            }

            final Double[] arr = new Double[xLen];
            for (int i = 0; i < xLen; i++) {
                Object xi = xListOI.getListElement(x, i);
                Object yi = yListOI.getListElement(y, i);
                if (xi == null || yi == null) {
                    continue; // null element in either operand -> null result element
                }
                double xd = PrimitiveObjectInspectorUtils.getDouble(xi, xElemOI);
                double yd = PrimitiveObjectInspectorUtils.getDouble(yi, yElemOI);
                double v = xd * yd;
                arr[i] = Double.valueOf(v);
            }

            return Arrays.asList(arr);
        }

    }

    @Override
    public String getDisplayString(String[] args) {
        return "vector_dot(" + StringUtils.join(args, ',') + ")";
    }

}
b/core/src/main/java/hivemall/utils/collections/DoubleRingBuffer.java @@ -71,6 +71,10 @@ public final class DoubleRingBuffer implements Iterable<Double> { return this; } + public double head() { + return ring[head]; + } + public void toArray(@Nonnull final double[] dst) { toArray(dst, true); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 44475eb..1be1a01 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -75,11 +75,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInsp import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantStringObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils; import org.apache.hadoop.io.BooleanWritable; import org.apache.hadoop.io.FloatWritable; @@ -102,8 +105,8 @@ public final class HiveUtils { if (o instanceof LongWritable) { long l = 
((LongWritable) o).get(); if (l > 0x7fffffffL) { - throw new IllegalArgumentException("feature index must be less than " - + Integer.MAX_VALUE + ", but was " + l); + throw new IllegalArgumentException( + "feature index must be less than " + Integer.MAX_VALUE + ", but was " + l); } return (int) l; } @@ -330,6 +333,15 @@ public final class HiveUtils { return ObjectInspectorUtils.isConstantObjectInspector(oi) && isListOI(oi); } + public static boolean isConstStringListOI(@Nonnull final ObjectInspector oi) + throws UDFArgumentException { + if (!isConstListOI(oi)) { + return false; + } + ListObjectInspector listOI = (ListObjectInspector) oi; + return isStringOI(listOI.getListElementObjectInspector()); + } + public static boolean isConstString(@Nonnull final ObjectInspector oi) { return ObjectInspectorUtils.isConstantObjectInspector(oi) && isStringOI(oi); } @@ -491,8 +503,8 @@ public final class HiveUtils { } ConstantObjectInspector constOI = (ConstantObjectInspector) oi; if (constOI.getCategory() != Category.LIST) { - throw new UDFArgumentException("argument must be an array: " - + TypeInfoUtils.getTypeInfoFromObjectInspector(oi)); + throw new UDFArgumentException( + "argument must be an array: " + TypeInfoUtils.getTypeInfoFromObjectInspector(oi)); } final List<?> lst = (List<?>) constOI.getWritableConstantValue(); if (lst == null) { @@ -518,11 +530,12 @@ public final class HiveUtils { } ConstantObjectInspector constOI = (ConstantObjectInspector) oi; if (constOI.getCategory() != Category.LIST) { - throw new UDFArgumentException("argument must be an array: " - + TypeInfoUtils.getTypeInfoFromObjectInspector(oi)); + throw new UDFArgumentException( + "argument must be an array: " + TypeInfoUtils.getTypeInfoFromObjectInspector(oi)); } StandardConstantListObjectInspector listOI = (StandardConstantListObjectInspector) constOI; - PrimitiveObjectInspector elemOI = HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector()); + PrimitiveObjectInspector elemOI = + 
HiveUtils.asDoubleCompatibleOI(listOI.getListElementObjectInspector()); final List<?> lst = listOI.getWritableConstantValue(); if (lst == null) { @@ -783,8 +796,8 @@ public final class HiveUtils { } final int length = listOI.getListLength(argObj); if (out.length != length) { - throw new UDFArgumentException("Dimension mismatched. Expected: " + out.length - + ", Actual: " + length); + throw new UDFArgumentException( + "Dimension mismatched. Expected: " + out.length + ", Actual: " + length); } for (int i = 0; i < length; i++) { Object o = listOI.getListElement(argObj, i); @@ -809,8 +822,8 @@ public final class HiveUtils { } final int length = listOI.getListLength(argObj); if (out.length != length) { - throw new UDFArgumentException("Dimension mismatched. Expected: " + out.length - + ", Actual: " + length); + throw new UDFArgumentException( + "Dimension mismatched. Expected: " + out.length + ", Actual: " + length); } for (int i = 0; i < length; i++) { Object o = listOI.getListElement(argObj, i); @@ -945,8 +958,8 @@ public final class HiveUtils { case STRING: break; default: - throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() - + "' is passed."); + throw new UDFArgumentTypeException(0, + "Unxpected type '" + argOI.getTypeName() + "' is passed."); } return oi; } @@ -972,8 +985,8 @@ public final class HiveUtils { case TIMESTAMP: break; default: - throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() - + "' is passed."); + throw new UDFArgumentTypeException(0, + "Unxpected type '" + argOI.getTypeName() + "' is passed."); } return oi; } @@ -993,15 +1006,15 @@ public final class HiveUtils { case BYTE: break; default: - throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() - + "' is passed."); + throw new UDFArgumentTypeException(0, + "Unxpected type '" + argOI.getTypeName() + "' is passed."); } return oi; } @Nonnull - public static PrimitiveObjectInspector asDoubleCompatibleOI(@Nonnull final 
ObjectInspector argOI) - throws UDFArgumentTypeException { + public static PrimitiveObjectInspector asDoubleCompatibleOI( + @Nonnull final ObjectInspector argOI) throws UDFArgumentTypeException { if (argOI.getCategory() != Category.PRIMITIVE) { throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + argOI.getTypeName() + " is passed."); @@ -1164,8 +1177,8 @@ public final class HiveUtils { @Nonnull public static LazyString lazyString(@Nonnull final String str, final byte escapeChar) { - LazyStringObjectInspector oi = LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector( - false, escapeChar); + LazyStringObjectInspector oi = + LazyPrimitiveObjectInspectorFactory.getLazyStringObjectInspector(false, escapeChar); return lazyString(str, oi); } @@ -1182,17 +1195,36 @@ public final class HiveUtils { @Nonnull public static LazyInteger lazyInteger(@Nonnull final int v) { - LazyInteger lazy = new LazyInteger( - LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR); + LazyInteger lazy = + new LazyInteger(LazyPrimitiveObjectInspectorFactory.LAZY_INT_OBJECT_INSPECTOR); lazy.getWritableObject().set(v); return lazy; } @Nonnull public static LazyLong lazyLong(@Nonnull final long v) { - LazyLong lazy = new LazyLong(LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR); + LazyLong lazy = + new LazyLong(LazyPrimitiveObjectInspectorFactory.LAZY_LONG_OBJECT_INSPECTOR); lazy.getWritableObject().set(v); return lazy; } + @Nonnull + public static ObjectInspector getObjectInspector(@Nonnull final String typeString, + boolean preferWritable) { + TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeString); + if (preferWritable) { + return TypeInfoUtils.getStandardWritableObjectInspectorFromTypeInfo(typeInfo); + } else { + return TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(typeInfo); + } + } + + @Nonnull + public static WritableConstantStringObjectInspector getConstStringObjectInspector( + @Nonnull final 
String str) { + return (WritableConstantStringObjectInspector) PrimitiveObjectInspectorFactory.getPrimitiveWritableConstantObjectInspector( + TypeInfoFactory.stringTypeInfo, new Text(str)); + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java b/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java new file mode 100644 index 0000000..1315537 --- /dev/null +++ b/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java @@ -0,0 +1,715 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +// This file codes borrowed from +// - org.apache.hive.hcatalog.data.JsonSerDe +package hivemall.utils.hadoop; + +import hivemall.utils.io.FastByteArrayInputStream; +import hivemall.utils.lang.Preconditions; + +import java.io.IOException; +import java.nio.charset.CharacterCodingException; +import java.sql.Date; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.annotation.CheckForNull; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.hadoop.hive.common.type.HiveChar; +import org.apache.hadoop.hive.common.type.HiveDecimal; +import org.apache.hadoop.hive.common.type.HiveVarchar; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.serde2.SerDeException; +import org.apache.hadoop.hive.serde2.SerDeUtils; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.UnionObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.ByteObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DateObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveCharObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveVarcharObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.ShortObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.TimestampObjectInspector; +import org.apache.hadoop.hive.serde2.typeinfo.BaseCharTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory; +import org.apache.hadoop.io.Text; +import org.apache.hive.hcatalog.common.HCatException; +import org.apache.hive.hcatalog.data.schema.HCatFieldSchema; +import org.apache.hive.hcatalog.data.schema.HCatFieldSchema.Type; +import org.apache.hive.hcatalog.data.schema.HCatSchema; +import org.apache.hive.hcatalog.data.schema.HCatSchemaUtils; +import org.codehaus.jackson.JsonFactory; +import org.codehaus.jackson.JsonParseException; +import org.codehaus.jackson.JsonParser; +import org.codehaus.jackson.JsonToken; + +public final class JsonSerdeUtils { + + /** + * Serialize Hive objects as Text. + */ + @Nonnull + public static Text serialize(@Nullable final Object obj, @Nonnull final ObjectInspector oi) + throws SerDeException { + return serialize(obj, oi, null); + } + + /** + * Serialize Hive objects as Text. 
+ */ + @Nonnull + public static Text serialize(@Nullable final Object obj, @Nonnull final ObjectInspector oi, + @Nullable final List<String> columnNames) throws SerDeException { + final StringBuilder sb = new StringBuilder(); + switch (oi.getCategory()) { + case STRUCT: + StructObjectInspector soi = (StructObjectInspector) oi; + serializeStruct(sb, obj, soi, columnNames); + break; + case LIST: + ListObjectInspector loi = (ListObjectInspector) oi; + serializeList(sb, obj, loi); + break; + case MAP: + MapObjectInspector moi = (MapObjectInspector) oi; + serializeMap(sb, obj, moi); + break; + case PRIMITIVE: + PrimitiveObjectInspector poi = (PrimitiveObjectInspector) oi; + serializePrimitive(sb, obj, poi); + break; + default: + throw new SerDeException("Unknown type in ObjectInspector: " + oi.getCategory()); + } + + return new Text(sb.toString()); + } + + /** + * Serialize Hive objects as Text. + */ + private static void serializeStruct(@Nonnull final StringBuilder sb, @Nullable final Object obj, + @Nonnull final StructObjectInspector soi, @Nullable final List<String> columnNames) + throws SerDeException { + if (obj == null) { + sb.append("null"); + } else { + final List<? 
extends StructField> structFields = soi.getAllStructFieldRefs(); + sb.append(SerDeUtils.LBRACE); + if (columnNames == null) { + for (int i = 0, len = structFields.size(); i < len; i++) { + String colName = structFields.get(i).getFieldName(); + if (i > 0) { + sb.append(SerDeUtils.COMMA); + } + appendWithQuotes(sb, colName); + sb.append(SerDeUtils.COLON); + buildJSONString(sb, soi.getStructFieldData(obj, structFields.get(i)), + structFields.get(i).getFieldObjectInspector()); + } + } else if (columnNames.size() == structFields.size()) { + for (int i = 0, len = structFields.size(); i < len; i++) { + if (i > 0) { + sb.append(SerDeUtils.COMMA); + } + String colName = columnNames.get(i); + appendWithQuotes(sb, colName); + sb.append(SerDeUtils.COLON); + buildJSONString(sb, soi.getStructFieldData(obj, structFields.get(i)), + structFields.get(i).getFieldObjectInspector()); + } + } else { + Collections.sort(columnNames); + final List<String> found = new ArrayList<>(columnNames.size()); + for (int i = 0, len = structFields.size(); i < len; i++) { + String colName = structFields.get(i).getFieldName(); + if (Collections.binarySearch(columnNames, colName) < 0) { + continue; + } + if (!found.isEmpty()) { + sb.append(SerDeUtils.COMMA); + } + appendWithQuotes(sb, colName); + sb.append(SerDeUtils.COLON); + buildJSONString(sb, soi.getStructFieldData(obj, structFields.get(i)), + structFields.get(i).getFieldObjectInspector()); + found.add(colName); + } + if (found.size() != columnNames.size()) { + ArrayList<String> expected = new ArrayList<>(columnNames); + expected.removeAll(found); + throw new SerDeException("Could not find some fields: " + expected); + } + } + sb.append(SerDeUtils.RBRACE); + } + } + + @Nonnull + private static void serializeList(@Nonnull final StringBuilder sb, @Nullable final Object obj, + @Nullable final ListObjectInspector loi) throws SerDeException { + ObjectInspector listElementObjectInspector = loi.getListElementObjectInspector(); + List<?> olist = 
loi.getList(obj); + + if (olist == null) { + sb.append("null"); + } else { + sb.append(SerDeUtils.LBRACKET); + for (int i = 0; i < olist.size(); i++) { + if (i > 0) { + sb.append(SerDeUtils.COMMA); + } + buildJSONString(sb, olist.get(i), listElementObjectInspector); + } + sb.append(SerDeUtils.RBRACKET); + } + } + + private static void serializeMap(@Nonnull final StringBuilder sb, @Nullable final Object obj, + @Nonnull final MapObjectInspector moi) throws SerDeException { + ObjectInspector mapKeyObjectInspector = moi.getMapKeyObjectInspector(); + ObjectInspector mapValueObjectInspector = moi.getMapValueObjectInspector(); + Map<?, ?> omap = moi.getMap(obj); + if (omap == null) { + sb.append("null"); + } else { + sb.append(SerDeUtils.LBRACE); + boolean first = true; + for (Object entry : omap.entrySet()) { + if (first) { + first = false; + } else { + sb.append(SerDeUtils.COMMA); + } + Map.Entry<?, ?> e = (Map.Entry<?, ?>) entry; + StringBuilder keyBuilder = new StringBuilder(); + buildJSONString(keyBuilder, e.getKey(), mapKeyObjectInspector); + String keyString = keyBuilder.toString().trim(); + if ((!keyString.isEmpty()) && (keyString.charAt(0) != SerDeUtils.QUOTE)) { + appendWithQuotes(sb, keyString); + } else { + sb.append(keyString); + } + sb.append(SerDeUtils.COLON); + buildJSONString(sb, e.getValue(), mapValueObjectInspector); + } + sb.append(SerDeUtils.RBRACE); + } + } + + private static void serializePrimitive(@Nonnull final StringBuilder sb, + @Nullable final Object obj, @Nullable final PrimitiveObjectInspector poi) + throws SerDeException { + if (obj == null) { + sb.append("null"); + } else { + switch (poi.getPrimitiveCategory()) { + case BOOLEAN: { + boolean b = ((BooleanObjectInspector) poi).get(obj); + sb.append(b ? 
"true" : "false"); + break; + } + case BYTE: { + sb.append(((ByteObjectInspector) poi).get(obj)); + break; + } + case SHORT: { + sb.append(((ShortObjectInspector) poi).get(obj)); + break; + } + case INT: { + sb.append(((IntObjectInspector) poi).get(obj)); + break; + } + case LONG: { + sb.append(((LongObjectInspector) poi).get(obj)); + break; + } + case FLOAT: { + sb.append(((FloatObjectInspector) poi).get(obj)); + break; + } + case DOUBLE: { + sb.append(((DoubleObjectInspector) poi).get(obj)); + break; + } + case STRING: { + String s = SerDeUtils.escapeString( + ((StringObjectInspector) poi).getPrimitiveJavaObject(obj)); + appendWithQuotes(sb, s); + break; + } + case BINARY: + byte[] b = ((BinaryObjectInspector) poi).getPrimitiveJavaObject(obj); + Text txt = new Text(); + txt.set(b, 0, b.length); + appendWithQuotes(sb, SerDeUtils.escapeString(txt.toString())); + break; + case DATE: + Date d = ((DateObjectInspector) poi).getPrimitiveJavaObject(obj); + appendWithQuotes(sb, d.toString()); + break; + case TIMESTAMP: { + Timestamp t = ((TimestampObjectInspector) poi).getPrimitiveJavaObject(obj); + appendWithQuotes(sb, t.toString()); + break; + } + case DECIMAL: + sb.append(((HiveDecimalObjectInspector) poi).getPrimitiveJavaObject(obj)); + break; + case VARCHAR: { + String s = SerDeUtils.escapeString( + ((HiveVarcharObjectInspector) poi).getPrimitiveJavaObject(obj).toString()); + appendWithQuotes(sb, s); + break; + } + case CHAR: { + //this should use HiveChar.getPaddedValue() but it's protected; currently (v0.13) + // HiveChar.toString() returns getPaddedValue() + String s = SerDeUtils.escapeString( + ((HiveCharObjectInspector) poi).getPrimitiveJavaObject(obj).toString()); + appendWithQuotes(sb, s); + break; + } + default: + throw new SerDeException( + "Unknown primitive type: " + poi.getPrimitiveCategory()); + } + } + } + + private static void buildJSONString(@Nonnull final StringBuilder sb, @Nullable final Object obj, + @Nonnull final ObjectInspector oi) throws 
SerDeException { + switch (oi.getCategory()) { + case PRIMITIVE: { + PrimitiveObjectInspector poi = (PrimitiveObjectInspector) oi; + serializePrimitive(sb, obj, poi); + break; + } + case LIST: { + ListObjectInspector loi = (ListObjectInspector) oi; + serializeList(sb, obj, loi); + break; + } + case MAP: { + MapObjectInspector moi = (MapObjectInspector) oi; + serializeMap(sb, obj, moi); + break; + } + case STRUCT: { + StructObjectInspector soi = (StructObjectInspector) oi; + serializeStruct(sb, obj, soi, null); + break; + } + case UNION: { + UnionObjectInspector uoi = (UnionObjectInspector) oi; + if (obj == null) { + sb.append("null"); + } else { + sb.append(SerDeUtils.LBRACE); + sb.append(uoi.getTag(obj)); + sb.append(SerDeUtils.COLON); + buildJSONString(sb, uoi.getField(obj), + uoi.getObjectInspectors().get(uoi.getTag(obj))); + sb.append(SerDeUtils.RBRACE); + } + break; + } + default: + throw new SerDeException("Unknown type in ObjectInspector: " + oi.getCategory()); + } + } + + @Nonnull + public static <T> T deserialize(@Nonnull final Text t) throws SerDeException { + return deserialize(t, null, null); + } + + /** + * Deserialize Json array or Json primitives. 
+ */ + @Nonnull + public static <T> T deserialize(@Nonnull final Text t, @Nonnull TypeInfo columnTypes) + throws SerDeException { + return deserialize(t, null, Arrays.asList(columnTypes)); + } + + @SuppressWarnings("unchecked") + @Nonnull + public static <T> T deserialize(@Nonnull final Text t, @Nullable final List<String> columnNames, + @Nullable final List<TypeInfo> columnTypes) throws SerDeException { + final Object result; + try { + JsonParser p = + new JsonFactory().createJsonParser(new FastByteArrayInputStream(t.getBytes())); + final JsonToken token = p.nextToken(); + if (token == JsonToken.START_OBJECT) { + result = parseObject(p, columnNames, columnTypes); + } else if (token == JsonToken.START_ARRAY) { + result = parseArray(p, columnTypes); + } else { + result = parseValue(p); + } + } catch (JsonParseException e) { + throw new SerDeException(e); + } catch (IOException e) { + throw new SerDeException(e); + } + return (T) result; + } + + @Nonnull + private static Object parseObject(@Nonnull final JsonParser p, + @CheckForNull final List<String> columnNames, + @CheckForNull final List<TypeInfo> columnTypes) + throws JsonParseException, IOException, SerDeException { + Preconditions.checkNotNull(columnNames, "columnNames MUST NOT be null in parseObject", + SerDeException.class); + Preconditions.checkNotNull(columnTypes, "columnTypes MUST NOT be null in parseObject", + SerDeException.class); + if (columnNames.size() != columnTypes.size()) { + throw new SerDeException( + "Size of columnNames and columnTypes does not match. 
#columnNames=" + + columnNames.size() + ", #columnTypes=" + columnTypes.size()); + } + + TypeInfo rowTypeInfo = TypeInfoFactory.getStructTypeInfo(columnNames, columnTypes); + final HCatSchema schema; + try { + schema = HCatSchemaUtils.getHCatSchema(rowTypeInfo).get(0).getStructSubSchema(); + } catch (HCatException e) { + throw new SerDeException(e); + } + + final List<Object> r = new ArrayList<Object>(Collections.nCopies(columnNames.size(), null)); + JsonToken token; + while (((token = p.nextToken()) != JsonToken.END_OBJECT) && (token != null)) { + // iterate through each token, and create appropriate object here. + populateRecord(r, token, p, schema); + } + + if (columnTypes.size() == 1) { + return r.get(0); + } + return r; + } + + @Nonnull + private static List<Object> parseArray(@Nonnull final JsonParser p, + @CheckForNull final List<TypeInfo> columnTypes) + throws HCatException, IOException, SerDeException { + Preconditions.checkNotNull(columnTypes, "columnTypes MUST NOT be null", + SerDeException.class); + if (columnTypes.size() != 1) { + throw new IOException("Expected a single array but got " + columnTypes); + } + + TypeInfo elemType = columnTypes.get(0); + HCatSchema schema = HCatSchemaUtils.getHCatSchema(elemType); + + HCatFieldSchema listSchema = schema.get(0); + HCatFieldSchema elemSchema = listSchema.getArrayElementSchema().get(0); + + final List<Object> arr = new ArrayList<Object>(); + while (p.nextToken() != JsonToken.END_ARRAY) { + arr.add(extractCurrentField(p, elemSchema, true)); + } + return arr; + } + + @Nonnull + private static Object parseValue(@Nonnull final JsonParser p) + throws JsonParseException, IOException { + final JsonToken t = p.getCurrentToken(); + switch (t) { + case VALUE_FALSE: + return Boolean.FALSE; + case VALUE_TRUE: + return Boolean.TRUE; + case VALUE_NULL: + return null; + case VALUE_STRING: + return p.getText(); + case VALUE_NUMBER_FLOAT: + return p.getDoubleValue(); + case VALUE_NUMBER_INT: + return p.getIntValue(); + 
default: + throw new IOException("Unexpected token: " + t); + } + } + + private static void populateRecord(@Nonnull final List<Object> r, + @Nonnull final JsonToken token, @Nonnull final JsonParser p, + @Nonnull final HCatSchema s) throws IOException { + if (token != JsonToken.FIELD_NAME) { + throw new IOException("Field name expected"); + } + String fieldName = p.getText(); + Integer fpos = s.getPosition(fieldName); + if (fpos == null) { + fpos = getPositionFromHiveInternalColumnName(fieldName); + if (fpos == -1) { + skipValue(p); + return; // unknown field, we return. We'll continue from the next field onwards. + } + // If we get past this, then the column name did match the hive pattern for an internal + // column name, such as _col0, etc, so it *MUST* match the schema for the appropriate column. + // This means people can't use arbitrary column names such as _col0, and expect us to ignore it + // if we find it. + if (!fieldName.equalsIgnoreCase(getHiveInternalColumnName(fpos))) { + throw new IOException("Hive internal column name (" + fieldName + + ") and position encoding (" + fpos + ") for the column name are at odds"); + } + // If we reached here, then we were successful at finding an alternate internal + // column mapping, and we're about to proceed. + } + HCatFieldSchema hcatFieldSchema = s.getFields().get(fpos); + Object currField = extractCurrentField(p, hcatFieldSchema, false); + r.set(fpos, currField); + } + + @SuppressWarnings("deprecation") + @Nullable + private static Object extractCurrentField(@Nonnull final JsonParser p, + @Nonnull final HCatFieldSchema hcatFieldSchema, final boolean isTokenCurrent) + throws IOException { + JsonToken valueToken; + if (isTokenCurrent) { + valueToken = p.getCurrentToken(); + } else { + valueToken = p.nextToken(); + } + + final Object val; + switch (hcatFieldSchema.getType()) { + case INT: + val = (valueToken == JsonToken.VALUE_NULL) ? 
null : p.getIntValue(); + break; + case TINYINT: + val = (valueToken == JsonToken.VALUE_NULL) ? null : p.getByteValue(); + break; + case SMALLINT: + val = (valueToken == JsonToken.VALUE_NULL) ? null : p.getShortValue(); + break; + case BIGINT: + val = (valueToken == JsonToken.VALUE_NULL) ? null : p.getLongValue(); + break; + case BOOLEAN: + String bval = (valueToken == JsonToken.VALUE_NULL) ? null : p.getText(); + if (bval != null) { + val = Boolean.valueOf(bval); + } else { + val = null; + } + break; + case FLOAT: + val = (valueToken == JsonToken.VALUE_NULL) ? null : p.getFloatValue(); + break; + case DOUBLE: + val = (valueToken == JsonToken.VALUE_NULL) ? null : p.getDoubleValue(); + break; + case STRING: + val = (valueToken == JsonToken.VALUE_NULL) ? null : p.getText(); + break; + case BINARY: + String b = (valueToken == JsonToken.VALUE_NULL) ? null : p.getText(); + if (b != null) { + try { + String t = Text.decode(b.getBytes(), 0, b.getBytes().length); + return t.getBytes(); + } catch (CharacterCodingException e) { + throw new IOException("Error generating json binary type from object.", e); + } + } else { + val = null; + } + break; + case DATE: + val = (valueToken == JsonToken.VALUE_NULL) ? null : Date.valueOf(p.getText()); + break; + case TIMESTAMP: + val = (valueToken == JsonToken.VALUE_NULL) ? null : Timestamp.valueOf(p.getText()); + break; + case DECIMAL: + val = (valueToken == JsonToken.VALUE_NULL) ? null : HiveDecimal.create(p.getText()); + break; + case VARCHAR: + int vLen = ((BaseCharTypeInfo) hcatFieldSchema.getTypeInfo()).getLength(); + val = (valueToken == JsonToken.VALUE_NULL) ? null + : new HiveVarchar(p.getText(), vLen); + break; + case CHAR: + int cLen = ((BaseCharTypeInfo) hcatFieldSchema.getTypeInfo()).getLength(); + val = (valueToken == JsonToken.VALUE_NULL) ? 
null : new HiveChar(p.getText(), cLen); + break; + case ARRAY: + if (valueToken == JsonToken.VALUE_NULL) { + val = null; + break; + } + if (valueToken != JsonToken.START_ARRAY) { + throw new IOException("Start of Array expected"); + } + final List<Object> arr = new ArrayList<>(); + final HCatFieldSchema elemSchema = hcatFieldSchema.getArrayElementSchema().get(0); + while ((valueToken = p.nextToken()) != JsonToken.END_ARRAY) { + arr.add(extractCurrentField(p, elemSchema, true)); + } + val = arr; + break; + case MAP: + if (valueToken == JsonToken.VALUE_NULL) { + val = null; + break; + } + if (valueToken != JsonToken.START_OBJECT) { + throw new IOException("Start of Object expected"); + } + final Map<Object, Object> map = new LinkedHashMap<>(); + final HCatFieldSchema valueSchema = hcatFieldSchema.getMapValueSchema().get(0); + while ((valueToken = p.nextToken()) != JsonToken.END_OBJECT) { + Object k = getObjectOfCorrespondingPrimitiveType(p.getCurrentName(), + hcatFieldSchema.getMapKeyTypeInfo()); + Object v = extractCurrentField(p, valueSchema, false); + map.put(k, v); + } + val = map; + break; + case STRUCT: + if (valueToken == JsonToken.VALUE_NULL) { + val = null; + break; + } + if (valueToken != JsonToken.START_OBJECT) { + throw new IOException("Start of Object expected"); + } + HCatSchema subSchema = hcatFieldSchema.getStructSubSchema(); + int sz = subSchema.getFieldNames().size(); + + List<Object> struct = new ArrayList<>(Collections.nCopies(sz, null)); + while ((valueToken = p.nextToken()) != JsonToken.END_OBJECT) { + populateRecord(struct, valueToken, p, subSchema); + } + val = struct; + break; + default: + throw new IOException("Unknown type found: " + hcatFieldSchema.getType()); + } + return val; + } + + @Nonnull + private static Object getObjectOfCorrespondingPrimitiveType(String s, + PrimitiveTypeInfo mapKeyType) throws IOException { + switch (Type.getPrimitiveHType(mapKeyType)) { + case INT: + return Integer.valueOf(s); + case TINYINT: + return 
Byte.valueOf(s); + case SMALLINT: + return Short.valueOf(s); + case BIGINT: + return Long.valueOf(s); + case BOOLEAN: + return (s.equalsIgnoreCase("true")); + case FLOAT: + return Float.valueOf(s); + case DOUBLE: + return Double.valueOf(s); + case STRING: + return s; + case BINARY: + try { + String t = Text.decode(s.getBytes(), 0, s.getBytes().length); + return t.getBytes(); + } catch (CharacterCodingException e) { + throw new IOException("Error generating json binary type from object.", e); + } + case DATE: + return Date.valueOf(s); + case TIMESTAMP: + return Timestamp.valueOf(s); + case DECIMAL: + return HiveDecimal.create(s); + case VARCHAR: + return new HiveVarchar(s, ((BaseCharTypeInfo) mapKeyType).getLength()); + case CHAR: + return new HiveChar(s, ((BaseCharTypeInfo) mapKeyType).getLength()); + default: + throw new IOException( + "Could not convert from string to map type " + mapKeyType.getTypeName()); + } + } + + private static int getPositionFromHiveInternalColumnName(String internalName) { + Pattern internalPattern = Pattern.compile("_col([0-9]+)"); + Matcher m = internalPattern.matcher(internalName); + if (!m.matches()) { + return -1; + } else { + return Integer.parseInt(m.group(1)); + } + } + + private static void skipValue(@Nonnull final JsonParser p) + throws JsonParseException, IOException { + JsonToken valueToken = p.nextToken(); + if ((valueToken == JsonToken.START_ARRAY) || (valueToken == JsonToken.START_OBJECT)) { + // if the currently read token is a beginning of an array or object, move stream forward + // skipping any child tokens till we're at the corresponding END_ARRAY or END_OBJECT token + p.skipChildren(); + } + } + + @Nonnull + private static String getHiveInternalColumnName(int fpos) { + return HiveConf.getColumnInternalName(fpos); + } + + @Nonnull + private static StringBuilder appendWithQuotes(@Nonnull final StringBuilder sb, + @Nonnull final String value) { + return sb.append(SerDeUtils.QUOTE).append(value).append(SerDeUtils.QUOTE); 
+ } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java index a9c7390..280107d 100644 --- a/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/WritableUtils.java @@ -109,6 +109,24 @@ public final class WritableUtils { } @Nonnull + public static List<IntWritable> toWritableList(@Nonnull final int[] src) { + // workaround to avoid a bug in Kryo + // https://issues.apache.org/jira/browse/HIVE-12551 + /* + final LongWritable[] writables = new LongWritable[src.length]; + for (int i = 0; i < src.length; i++) { + writables[i] = new LongWritable(src[i]); + } + return Arrays.asList(writables); + */ + final List<IntWritable> list = new ArrayList<IntWritable>(src.length); + for (int i = 0; i < src.length; i++) { + list.add(new IntWritable(src[i])); + } + return list; + } + + @Nonnull public static List<LongWritable> toWritableList(@Nonnull final long[] src) { // workaround to avoid a bug in Kryo // https://issues.apache.org/jira/browse/HIVE-12551 @@ -127,6 +145,15 @@ public final class WritableUtils { } @Nonnull + public static List<FloatWritable> toWritableList(@Nonnull final float[] src) { + final List<FloatWritable> list = new ArrayList<FloatWritable>(src.length); + for (int i = 0; i < src.length; i++) { + list.add(new FloatWritable(src[i])); + } + return list; + } + + @Nonnull public static List<DoubleWritable> toWritableList(@Nonnull final double[] src) { // workaround to avoid a bug in Kryo // https://issues.apache.org/jira/browse/HIVE-12551 http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/hashing/HashFunctionFactory.java 
---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hashing/HashFunctionFactory.java b/core/src/main/java/hivemall/utils/hashing/HashFunctionFactory.java index bc4339f..62fe67b 100644 --- a/core/src/main/java/hivemall/utils/hashing/HashFunctionFactory.java +++ b/core/src/main/java/hivemall/utils/hashing/HashFunctionFactory.java @@ -18,7 +18,6 @@ */ package hivemall.utils.hashing; - import java.util.Random; public final class HashFunctionFactory { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/lang/StringUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/StringUtils.java b/core/src/main/java/hivemall/utils/lang/StringUtils.java index 3652ebd..b83940b 100644 --- a/core/src/main/java/hivemall/utils/lang/StringUtils.java +++ b/core/src/main/java/hivemall/utils/lang/StringUtils.java @@ -193,7 +193,6 @@ public final class StringUtils { if (i > 0) { // append separator before each element, except for the head element buf.append(sep); } - final String s = list.get(i); if (s != null) { buf.append(s); @@ -202,6 +201,21 @@ public final class StringUtils { return buf.toString(); } + @Nonnull + public static String join(@Nonnull final Object[] list, @Nonnull final char sep) { + final StringBuilder buf = new StringBuilder(128); + for (int i = 0; i < list.length; i++) { + if (i > 0) { // append separator before each element, except for the head element + buf.append(sep); + } + final Object s = list[i]; + if (s != null) { + buf.append(s); + } + } + return buf.toString(); + } + @Nullable public static String[] split(@Nullable final String str, final char separatorChar) { return split(str, separatorChar, false); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/stats/MovingAverage.java 
---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/stats/MovingAverage.java b/core/src/main/java/hivemall/utils/stats/MovingAverage.java new file mode 100644 index 0000000..0a2cbf6 --- /dev/null +++ b/core/src/main/java/hivemall/utils/stats/MovingAverage.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.utils.stats; + +import hivemall.utils.collections.DoubleRingBuffer; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Preconditions; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class MovingAverage { + + @Nonnull + private final DoubleRingBuffer ring; + + private double totalSum; + + public MovingAverage(@Nonnegative int windowSize) { + Preconditions.checkArgument(windowSize > 1, "Invalid window size: " + windowSize); + this.ring = new DoubleRingBuffer(windowSize); + this.totalSum = 0.d; + } + + public double add(final double x) { + if (!NumberUtils.isFinite(x)) { + throw new IllegalArgumentException("Detected Infinite input: " + x); + } + + if (ring.isFull()) { + double head = ring.head(); + this.totalSum -= head; + } + ring.add(x); + totalSum += x; + + final int size = ring.size(); + if (size == 0) { + return 0.d; + } + return totalSum / size; + } + + public double get() { + final int size = ring.size(); + if (size == 0) { + return 0.d; + } + return totalSum / size; + } + + @Override + public String toString() { + return "MovingAverage [ring=" + ring + ", total=" + totalSum + ", moving_avg=" + get() + + "]"; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/main/java/hivemall/utils/stats/OnlineVariance.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/stats/OnlineVariance.java b/core/src/main/java/hivemall/utils/stats/OnlineVariance.java new file mode 100644 index 0000000..92a85f6 --- /dev/null +++ b/core/src/main/java/hivemall/utils/stats/OnlineVariance.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.utils.stats; + +/** + * @see <a href="http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance">Algorithms for calculating variance</a> + */ +public final class OnlineVariance { + + private long n; + private double mean; + private double m2; + + public OnlineVariance() { + reset(); + } + + public void reset() { + this.n = 0L; + this.mean = 0.d; + this.m2 = 0.d; + } + + public void handle(double x) { + ++n; + double delta = x - mean; + mean += delta / n; + m2 += delta * (x - mean); + } + + public void unhandle(double x) { + if (n == 0L) { + return; // nop + } + if (n == 1L) { + reset(); + return; + } + double old_mean = (n * mean - x) / (n - 1L); + m2 -= (x - mean) * (x - old_mean); + mean = old_mean; + --n; + } + + public long numSamples() { + return n; + } + + public double mean() { + return mean; + } + + public double variance() { + return n > 1 ? 
(m2 / (n - 1)) : 0.d; + } + + public double stddev() { + return Math.sqrt(variance()); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/test/java/hivemall/common/OnlineVarianceTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/common/OnlineVarianceTest.java b/core/src/test/java/hivemall/common/OnlineVarianceTest.java deleted file mode 100644 index 2308dea..0000000 --- a/core/src/test/java/hivemall/common/OnlineVarianceTest.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package hivemall.common; - -import java.util.Collections; -import java.util.ArrayList; -import java.util.Random; - -import static org.junit.Assert.assertEquals; - -import org.junit.Test; - -public class OnlineVarianceTest { - - @Test - public void testSimple() { - OnlineVariance onlineVariance = new OnlineVariance(); - - long n = 0L; - double sum = 0.d; - double sumOfSquare = 0.d; - - assertEquals(0L, onlineVariance.numSamples()); - assertEquals(0.d, onlineVariance.mean(), 1e-5f); - assertEquals(0.d, onlineVariance.variance(), 1e-5f); - assertEquals(0.d, onlineVariance.stddev(), 1e-5f); - - Random rnd = new Random(); - ArrayList<Double> dArrayList = new ArrayList<Double>(); - - for (int i = 0; i < 10; i++) { - double x = rnd.nextDouble(); - dArrayList.add(x); - onlineVariance.handle(x); - - n++; - sum += x; - sumOfSquare += x * x; - - double mean = n > 0 ? (sum / n) : 0.d; - double sampleVariance = n > 0 ? ((sumOfSquare / n) - mean * mean) : 0.d; - double unbiasedVariance = n > 1 ? (sampleVariance * n / (n - 1)) : 0.d; - double stddev = Math.sqrt(unbiasedVariance); - - assertEquals(n, onlineVariance.numSamples()); - assertEquals(mean, onlineVariance.mean(), 1e-5f); - assertEquals(unbiasedVariance, onlineVariance.variance(), 1e-5f); - assertEquals(stddev, onlineVariance.stddev(), 1e-5f); - } - - Collections.shuffle(dArrayList); - - for (Double x : dArrayList) { - onlineVariance.unhandle(x.doubleValue()); - - n--; - sum -= x; - sumOfSquare -= x * x; - - double mean = n > 0 ? (sum / n) : 0.d; - double sampleVariance = n > 0 ? ((sumOfSquare / n) - mean * mean) : 0.d; - double unbiasedVariance = n > 1 ? 
(sampleVariance * n / (n - 1)) : 0.d; - double stddev = Math.sqrt(unbiasedVariance); - - assertEquals(n, onlineVariance.numSamples()); - assertEquals(mean, onlineVariance.mean(), 1e-5f); - assertEquals(unbiasedVariance, onlineVariance.variance(), 1e-5f); - assertEquals(stddev, onlineVariance.stddev(), 1e-5f); - } - - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/test/java/hivemall/sketch/bloom/BloomAndUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/sketch/bloom/BloomAndUDFTest.java b/core/src/test/java/hivemall/sketch/bloom/BloomAndUDFTest.java new file mode 100644 index 0000000..97ad7c6 --- /dev/null +++ b/core/src/test/java/hivemall/sketch/bloom/BloomAndUDFTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.sketch.bloom; + +import java.io.IOException; +import java.util.Random; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.util.bloom.DynamicBloomFilter; +import org.apache.hadoop.util.bloom.Filter; +import org.apache.hadoop.util.bloom.Key; +import org.junit.Assert; +import org.junit.Test; + +public class BloomAndUDFTest { + + @Test + public void test() throws IOException, HiveException { + BloomAndUDF udf = new BloomAndUDF(); + + DynamicBloomFilter bf1 = createBloomFilter(1L, 10000); + DynamicBloomFilter bf2 = createBloomFilter(2L, 10000); + + Text bf1str = BloomFilterUtils.serialize(bf1, new Text()); + Text bf2str = BloomFilterUtils.serialize(bf2, new Text()); + + bf1.and(bf2); + Text expected = BloomFilterUtils.serialize(bf1, new Text()); + + Text actual = udf.evaluate(bf1str, bf2str); + + Assert.assertEquals(expected, actual); + + DynamicBloomFilter deserialized = + BloomFilterUtils.deserialize(actual, new DynamicBloomFilter()); + assertNotContains(bf1, deserialized, 1L, 10000); + assertNotContains(bf1, deserialized, 2L, 10000); + } + + @Nonnull + private static DynamicBloomFilter createBloomFilter(long seed, int size) { + DynamicBloomFilter dbf = BloomFilterUtils.newDynamicBloomFilter(3000); + final Key key = new Key(); + + final Random rnd1 = new Random(seed); + for (int i = 0; i < size; i++) { + double d = rnd1.nextGaussian(); + String s = Double.toHexString(d); + + key.set(s.getBytes(), 1.0); + dbf.add(key); + } + + return dbf; + } + + private static void assertNotContains(@Nonnull Filter expected, @Nonnull Filter actual, + long seed, int size) { + final Key key = new Key(); + + final Random rnd1 = new Random(seed); + for (int i = 0; i < size; i++) { + double d = rnd1.nextGaussian(); + String s = Double.toHexString(d); + key.set(s.getBytes(), 1.0); + Assert.assertEquals(expected.membershipTest(key), actual.membershipTest(key)); + } 
+ } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/test/java/hivemall/sketch/bloom/BloomContainsUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/sketch/bloom/BloomContainsUDFTest.java b/core/src/test/java/hivemall/sketch/bloom/BloomContainsUDFTest.java new file mode 100644 index 0000000..58e0db9 --- /dev/null +++ b/core/src/test/java/hivemall/sketch/bloom/BloomContainsUDFTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.sketch.bloom; + +import java.io.IOException; +import java.util.Random; + +import javax.annotation.Nonnull; + +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.util.bloom.DynamicBloomFilter; +import org.apache.hadoop.util.bloom.Key; +import org.junit.Assert; +import org.junit.Test; + +public class BloomContainsUDFTest { + + @Test + public void test() throws IOException, HiveException { + BloomContainsUDF udf = new BloomContainsUDF(); + final long seed = 43L; + final int size = 100; + + DynamicBloomFilter dbf = createBloomFilter(seed, size); + Text bfstr = BloomFilterUtils.serialize(dbf, new Text()); + + final Text key = new Text(); + final Random rnd1 = new Random(seed); + for (int i = 0; i < size; i++) { + double d = rnd1.nextGaussian(); + String s = Double.toHexString(d); + key.set(s); + Assert.assertEquals(Boolean.TRUE, udf.evaluate(bfstr, key)); + } + } + + @Nonnull + private static DynamicBloomFilter createBloomFilter(long seed, int size) { + DynamicBloomFilter dbf = BloomFilterUtils.newDynamicBloomFilter(30); + final Key key = new Key(); + + final Random rnd1 = new Random(seed); + for (int i = 0; i < size; i++) { + double d = rnd1.nextGaussian(); + String s = Double.toHexString(d); + Text t = new Text(s); + key.set(t.getBytes(), 1.0); + dbf.add(key); + } + + return dbf; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/1e1b77ea/core/src/test/java/hivemall/sketch/bloom/BloomFilterUtilsTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/sketch/bloom/BloomFilterUtilsTest.java b/core/src/test/java/hivemall/sketch/bloom/BloomFilterUtilsTest.java new file mode 100644 index 0000000..82d8908 --- /dev/null +++ b/core/src/test/java/hivemall/sketch/bloom/BloomFilterUtilsTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.sketch.bloom;

import java.io.IOException;
import java.util.Random;

import org.apache.hadoop.util.bloom.DynamicBloomFilter;
import org.apache.hadoop.util.bloom.Key;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link BloomFilterUtils}: membership after insertion, and membership
 * after a serialize/deserialize round-trip. Keys come from a seeded Gaussian
 * sequence so insertion and probing replay the exact same keys.
 */
public class BloomFilterUtilsTest {

    /** Shared PRNG seed so the probe pass regenerates exactly the inserted keys. */
    private static final long SEED = 43L;
    /** Expected number of distinct keys handed to the filter factory. */
    private static final int EXPECTED_KEYS = 300000;
    /** Number of keys inserted and probed. */
    private static final int NUM_KEYS = 1000000;

    @Test
    public void testDynamicBloomFilter() {
        final DynamicBloomFilter filter = BloomFilterUtils.newDynamicBloomFilter(EXPECTED_KEYS);
        final Key probe = new Key();

        // Insert NUM_KEYS deterministic keys.
        final Random insertRnd = new Random(SEED);
        for (int i = 0; i < NUM_KEYS; i++) {
            probe.set(Double.toHexString(insertRnd.nextGaussian()).getBytes(), 1.0);
            filter.add(probe);
        }

        // Replay the identical sequence: a bloom filter must never yield a false negative.
        final Random probeRnd = new Random(SEED);
        for (int i = 0; i < NUM_KEYS; i++) {
            probe.set(Double.toHexString(probeRnd.nextGaussian()).getBytes(), 1.0);
            Assert.assertTrue(filter.membershipTest(probe));
        }
    }

    @Test
    public void testDynamicBloomFilterSerde() throws IOException {
        final Key probe = new Key();

        // Populate a filter with NUM_KEYS deterministic keys.
        final DynamicBloomFilter original = BloomFilterUtils.newDynamicBloomFilter(EXPECTED_KEYS);
        final Random insertRnd = new Random(SEED);
        for (int i = 0; i < NUM_KEYS; i++) {
            probe.set(Double.toHexString(insertRnd.nextGaussian()).getBytes(), 1.0);
            original.add(probe);
        }

        // Round-trip through serialize/deserialize and verify no key was lost.
        final DynamicBloomFilter restored =
                BloomFilterUtils.deserialize(BloomFilterUtils.serialize(original),
                    new DynamicBloomFilter());
        final Random probeRnd = new Random(SEED);
        for (int i = 0; i < NUM_KEYS; i++) {
            probe.set(Double.toHexString(probeRnd.nextGaussian()).getBytes(), 1.0);
            Assert.assertTrue(restored.membershipTest(probe));
        }
    }

}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.sketch.bloom;

import java.io.IOException;
import java.util.Random;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.bloom.DynamicBloomFilter;
import org.apache.hadoop.util.bloom.Key;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit test for {@code bloom_not()}: the UDF's result must equal the filter
 * produced by applying {@link DynamicBloomFilter#not()} locally.
 */
public class BloomNotUDFTest {

    @Test
    public void test() throws IOException, HiveException {
        BloomNotUDF udf = new BloomNotUDF();

        // Serialize a populated filter and run it through the UDF.
        DynamicBloomFilter original = buildFilter(1L, 10000);
        Text serialized = BloomFilterUtils.serialize(original, new Text());
        Text negatedStr = udf.evaluate(serialized);
        DynamicBloomFilter viaUdf = BloomFilterUtils.deserialize(negatedStr, new DynamicBloomFilter());

        // Negate the local copy and compare string renderings of the bit matrices.
        original.not();
        Assert.assertEquals(original.toString(), viaUdf.toString());
    }

    /**
     * Builds a {@link DynamicBloomFilter} filled with {@code size} keys taken from
     * a Gaussian sequence seeded with {@code seed}.
     */
    @Nonnull
    private static DynamicBloomFilter buildFilter(long seed, int size) {
        DynamicBloomFilter filter = BloomFilterUtils.newDynamicBloomFilter(3000);
        final Key entry = new Key();

        final Random rnd = new Random(seed);
        for (int i = 0; i < size; i++) {
            String hex = Double.toHexString(rnd.nextGaussian());
            entry.set(hex.getBytes(), 1.0);
            filter.add(entry);
        }

        return filter;
    }

}
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.sketch.bloom;

import java.io.IOException;
import java.util.Random;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.bloom.DynamicBloomFilter;
import org.apache.hadoop.util.bloom.Filter;
import org.apache.hadoop.util.bloom.Key;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit test for {@code bloom_or()}: the UDF's union must match the union
 * computed locally with {@link DynamicBloomFilter#or(Filter)}, and the union
 * must contain every key inserted into either operand.
 */
public class BloomOrUDFTest {

    @Test
    public void test() throws IOException, HiveException {
        BloomOrUDF udf = new BloomOrUDF();

        DynamicBloomFilter left = buildFilter(1L, 10000);
        DynamicBloomFilter right = buildFilter(2L, 10000);

        Text leftStr = BloomFilterUtils.serialize(left, new Text());
        Text rightStr = BloomFilterUtils.serialize(right, new Text());

        // Compute the expected union locally; `left` is mutated in place.
        left.or(right);
        Text expected = BloomFilterUtils.serialize(left, new Text());

        Text actual = udf.evaluate(leftStr, rightStr);
        Assert.assertEquals(expected, actual);

        // The union must agree with the local result for keys from both operands.
        DynamicBloomFilter union =
                BloomFilterUtils.deserialize(actual, new DynamicBloomFilter());
        assertEquals(left, union, 1L, 10000);
        assertEquals(left, union, 2L, 10000);
    }

    /**
     * Builds a {@link DynamicBloomFilter} filled with {@code size} keys taken from
     * a Gaussian sequence seeded with {@code seed}.
     */
    @Nonnull
    private static DynamicBloomFilter buildFilter(long seed, int size) {
        DynamicBloomFilter filter = BloomFilterUtils.newDynamicBloomFilter(3000);
        final Key entry = new Key();

        final Random rnd = new Random(seed);
        for (int i = 0; i < size; i++) {
            String hex = Double.toHexString(rnd.nextGaussian());
            entry.set(hex.getBytes(), 1.0);
            filter.add(entry);
        }

        return filter;
    }

    /**
     * Asserts that {@code expected} and {@code actual} give identical membership
     * answers for the key sequence generated from {@code seed}/{@code size}.
     */
    private static void assertEquals(@Nonnull Filter expected, @Nonnull Filter actual, long seed,
            int size) {
        final Key probe = new Key();

        final Random rnd = new Random(seed);
        for (int i = 0; i < size; i++) {
            String hex = Double.toHexString(rnd.nextGaussian());
            probe.set(hex.getBytes(), 1.0);
            Assert.assertEquals(expected.membershipTest(probe), actual.membershipTest(probe));
        }
    }

}