Implement initial SST-based change-point detector
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/3ebd771e Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/3ebd771e Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/3ebd771e Branch: refs/heads/JIRA-22/pr-356 Commit: 3ebd771ee4bebf14769b7c240f8b28b9d5d10e86 Parents: 89ec56e Author: Takuya Kitazawa <k.tak...@gmail.com> Authored: Mon Sep 26 17:12:01 2016 +0900 Committer: Takuya Kitazawa <k.tak...@gmail.com> Committed: Mon Sep 26 17:12:01 2016 +0900 ---------------------------------------------------------------------- .../java/hivemall/anomaly/SSTChangePoint.java | 118 +++++++++++ .../hivemall/anomaly/SSTChangePointUDF.java | 197 +++++++++++++++++++ .../hivemall/anomaly/SSTChangePointTest.java | 111 +++++++++++ 3 files changed, 426 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3ebd771e/core/src/main/java/hivemall/anomaly/SSTChangePoint.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/anomaly/SSTChangePoint.java b/core/src/main/java/hivemall/anomaly/SSTChangePoint.java new file mode 100644 index 0000000..e693bd4 --- /dev/null +++ b/core/src/main/java/hivemall/anomaly/SSTChangePoint.java @@ -0,0 +1,118 @@ +/* + * Hivemall: Hive scalable Machine Learning Library + * + * Copyright (C) 2015 Makoto YUI + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.anomaly.SSTChangePointUDF.SSTChangePointInterface;
+import hivemall.anomaly.SSTChangePointUDF.Parameters;
+import hivemall.utils.collections.DoubleRingBuffer;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+
+import java.util.Arrays;
+
+import javax.annotation.Nonnull;
+
+/**
+ * Change-point scorer based on Singular Spectrum Transformation (SST).
+ *
+ * <p>Each call to {@link #update(Object, double[])} appends one sample to a fixed-size ring
+ * buffer. Once the buffer is full, a "past" and a "current" trajectory matrix are built from
+ * lagged windows over the buffered samples, and the score is derived from the singular values
+ * of the product of their leading left singular vectors.
+ */
+final class SSTChangePoint implements SSTChangePointInterface {
+
+    /** Inspector used to read each incoming sample as a {@code double}. */
+    @Nonnull
+    private final PrimitiveObjectInspector oi;
+
+    /** Length of one lagged window ({@code params.w}); row count of both trajectory matrices. */
+    private final int window;
+    /** Number of columns of the past trajectory matrix ({@code params.n}). */
+    private final int nPastWindow;
+    /** Number of columns of the current trajectory matrix ({@code params.m}). */
+    private final int nCurrentWindow;
+    /** {@code window + nPastWindow}. */
+    private final int pastSize;
+    /** {@code window + nCurrentWindow}. */
+    private final int currentSize;
+    /** Offset ({@code params.g}) added to {@code pastSize} to locate the first current window. */
+    private final int currentOffset;
+    /** Number of leading singular vectors compared between the two matrices ({@code params.r}). */
+    private final int r;
+
+    /** Ring buffer holding the most recent samples. */
+    @Nonnull
+    private final DoubleRingBuffer xRing;
+    /** Scratch array the ring buffer is copied into (FIFO order) on every update. */
+    @Nonnull
+    private final double[] xSeries;
+
+    /**
+     * Creates a scorer for the given parameters.
+     *
+     * @param params window sizes, offset and number of components; read once, not retained
+     * @param oi inspector used to convert incoming samples to {@code double}
+     * @throws IllegalArgumentException if {@code r} is outside {@code [1, min(w, n, m)]}, or
+     *         if the offset {@code g} would place a window outside of the sample buffer
+     */
+    SSTChangePoint(@Nonnull Parameters params, @Nonnull PrimitiveObjectInspector oi) {
+        this.oi = oi;
+
+        this.window = params.w;
+        this.nPastWindow = params.n;
+        this.nCurrentWindow = params.m;
+        this.pastSize = window + nPastWindow;
+        this.currentSize = window + nCurrentWindow;
+        this.currentOffset = params.g;
+        this.r = params.r;
+
+        // Each SVD yields at most min(rows, columns) singular vectors, so asking for more than
+        // min(w, n, m) components would only fail later, inside getSubMatrix(), after the
+        // buffer has filled up. Fail fast instead.
+        final int maxComponents = Math.min(window, Math.min(nPastWindow, nCurrentWindow));
+        if (r < 1 || r > maxComponents) {
+            throw new IllegalArgumentException("r must be in range [1, min(w, n, m)]: r=" + r
+                    + ", w=" + window + ", n=" + nPastWindow + ", m=" + nCurrentWindow);
+        }
+
+        // (w + n) past samples for the n-past-windows
+        // (w + m) current samples for the m-current-windows, starting from offset g
+        // => need to hold past (w + n + g + w + m) samples from the latest sample
+        int holdSampleSize = pastSize + currentOffset + currentSize;
+
+        // The first current window starts at (pastSize + g), and the last past window ends
+        // (exclusively) at (pastSize - 1). Outside of these bounds Arrays.copyOfRange() either
+        // throws on a negative/too-large start index or silently pads the window with zeros.
+        final int currentHead = pastSize + currentOffset;
+        if (currentHead < 0 || pastSize - 1 > holdSampleSize) {
+            throw new IllegalArgumentException(
+                "Offset g places a window outside of the sample buffer: g=" + currentOffset
+                        + ", w=" + window + ", n=" + nPastWindow + ", m=" + nCurrentWindow);
+        }
+
+        this.xRing = new DoubleRingBuffer(holdSampleSize);
+        this.xSeries = new double[holdSampleSize];
+    }
+
+    /**
+     * Feeds one sample and writes the resulting change-point score to {@code outScores[0]}.
+     *
+     * <p>The score is {@code 0} until the ring buffer has been filled.
+     *
+     * @param arg the sample, readable as a double through the inspector given at construction
+     * @param outScores output array; only index 0 is written
+     */
+    @Override
+    public void update(@Nonnull final Object arg, @Nonnull final double[] outScores)
+            throws HiveException {
+        double x = PrimitiveObjectInspectorUtils.getDouble(arg, oi);
+        xRing.add(x).toArray(xSeries, true /* FIFO */);
+
+        // need to wait until the buffer is filled
+        if (!xRing.isFull()) {
+            outScores[0] = 0.d;
+        } else {
+            outScores[0] = computeScore();
+        }
+    }
+
+    /**
+     * Computes the score from the current content of {@code xSeries}.
+     *
+     * @return {@code 1 - s[0]}, where {@code s[0]} is the first singular value of the product
+     *         of the first {@code r} left singular vectors of the past and current matrices
+     */
+    private double computeScore() {
+        // create past trajectory matrix and find its left singular vectors
+        // (column i holds the window xSeries[i, i + window))
+        RealMatrix H = MatrixUtils.createRealMatrix(window, nPastWindow);
+        for (int i = 0; i < nPastWindow; i++) {
+            H.setColumn(i, Arrays.copyOfRange(xSeries, i, i + window));
+        }
+        SingularValueDecomposition svdH = new SingularValueDecomposition(H);
+        RealMatrix UT = svdH.getUT();
+
+        // create current trajectory matrix and find its left singular vectors
+        // (column i holds the window starting at currentHead + i)
+        RealMatrix G = MatrixUtils.createRealMatrix(window, nCurrentWindow);
+        int currentHead = pastSize + currentOffset;
+        for (int i = 0; i < nCurrentWindow; i++) {
+            G.setColumn(i,
+                Arrays.copyOfRange(xSeries, currentHead + i, currentHead + i + window));
+        }
+        SingularValueDecomposition svdG = new SingularValueDecomposition(G);
+        RealMatrix Q = svdG.getU();
+
+        // find the largest singular value for the r principal components:
+        // (first r rows of U^T) x (first r columns of Q) is an r-by-r matrix
+        RealMatrix UTQ = UT.getSubMatrix(0, r - 1, 0, window - 1)
+                           .multiply(Q.getSubMatrix(0, window - 1, 0, r - 1));
+        SingularValueDecomposition svdUTQ = new SingularValueDecomposition(UTQ);
+        double[] s = svdUTQ.getSingularValues();
+
+        return 1.d - s[0];
+    }
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3ebd771e/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java b/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
new file mode 100644
index 0000000..3ab5ae8
--- /dev/null
+++
b/core/src/main/java/hivemall/anomaly/SSTChangePointUDF.java
@@ -0,0 +1,197 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.UDFWithOptions;
+import hivemall.utils.collections.DoubleRingBuffer;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.Primitives;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.io.DoubleWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.apache.hadoop.io.BooleanWritable;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+
+/**
+ * Hive UDF that scores each input sample with {@link SSTChangePoint}.
+ *
+ * <p>The returned struct always has a {@code changepoint_score} field. An additional
+ * {@code is_changepoint} field is emitted only when a threshold other than the default
+ * {@code -1} has been configured.
+ */
+@Description(
+        name = "sst_changepoint",
+        value = "_FUNC_(double|array<double> x [, const string options])"
+                + " - Returns change-point scores and decisions using Singular Spectrum Transformation (SST)."
+                + " It will return a tuple <double changepoint_score [, boolean is_changepoint]>")
+public final class SSTChangePointUDF extends UDFWithOptions {
+
+    private transient Parameters _params;
+    private transient SSTChangePoint _sst;
+
+    /** One-element array the scorer writes its score into. */
+    private transient double[] _scores;
+    /** Result row, allocated once in {@link #initialize} and returned from every call. */
+    private transient Object[] _result;
+    private transient DoubleWritable _changepointScore;
+    /** {@code null} when no threshold is configured, i.e. no decision column is emitted. */
+    @Nullable
+    private transient BooleanWritable _isChangepoint = null;
+
+    public SSTChangePointUDF() {}
+
+    // Visible for testing
+    Parameters getParameters() {
+        return _params;
+    }
+
+    /** Declares the command-line style options accepted in the second UDF argument. */
+    @Override
+    protected Options getOptions() {
+        Options opts = new Options();
+        opts.addOption("w", "window", true,
+            "Number of samples which affects change-point score [default: 30]");
+        opts.addOption("n", "n_past", true,
+            "Number of past windows for change-point scoring [default: equal to `w` = 30]");
+        opts.addOption("m", "n_current", true,
+            "Number of current windows for change-point scoring [default: equal to `w` = 30]");
+        opts.addOption("g", "current_offset", true,
+            "Offset of the current windows from the updating sample [default: `-w` = -30]");
+        opts.addOption("r", "n_component", true,
+            "Number of singular vectors (i.e. principal components) [default: 3]");
+        opts.addOption("k", "n_dim", true,
+            "Number of dimensions for the Krylov subspaces [default: 5 (`2*r` if `r` is even, `2*r-1` otherwise)]");
+        opts.addOption("th", "threshold", true,
+            "Score threshold (inclusive) for determining change-point existence [default: -1, do not output decision]");
+        return opts;
+    }
+
+    /**
+     * Parses the option string into {@link #_params} and validates the result.
+     *
+     * <p>{@code n}, {@code m} and {@code g} default to values derived from the parsed
+     * {@code w}, and {@code k} defaults to a value derived from the parsed {@code r}.
+     *
+     * @param optionValues raw option string, e.g. {@code "-w 10 -th 0.5"}
+     * @return the parsed command line
+     * @throws UDFArgumentException declared by the overridden method
+     */
+    @Override
+    protected CommandLine processOptions(String optionValues) throws UDFArgumentException {
+        CommandLine cl = parseOptions(optionValues);
+
+        this._params.w = Primitives.parseInt(cl.getOptionValue("w"), _params.w);
+        this._params.n = Primitives.parseInt(cl.getOptionValue("n"), _params.w);
+        this._params.m = Primitives.parseInt(cl.getOptionValue("m"), _params.w);
+        this._params.g = Primitives.parseInt(cl.getOptionValue("g"), -1 * _params.w);
+        this._params.r = Primitives.parseInt(cl.getOptionValue("r"), _params.r);
+        this._params.k = Primitives.parseInt(
+            cl.getOptionValue("k"), (_params.r % 2 == 0) ? (2 * _params.r) : (2 * _params.r - 1));
+        final String threshold = cl.getOptionValue("th");
+        this._params.changepointThreshold = Primitives.parseDouble(
+            threshold, _params.changepointThreshold);
+
+        Preconditions.checkArgument(_params.w >= 2, "w must be greater than 1: " + _params.w);
+        Preconditions.checkArgument(_params.r >= 1, "r must be greater than 0: " + _params.r);
+        Preconditions.checkArgument(_params.k >= 1, "k must be greater than 0: " + _params.k);
+        // The default threshold (-1) is a sentinel meaning "do not output a decision", so the
+        // range check must only apply when a threshold was actually supplied. Checking it
+        // unconditionally rejected every option string that did not contain `-th`.
+        if (threshold != null) {
+            Preconditions.checkArgument(
+                _params.changepointThreshold > 0.d && _params.changepointThreshold < 1.d,
+                "changepointThreshold must be in range (0, 1): " + _params.changepointThreshold);
+        }
+
+        return cl;
+    }
+
+    /**
+     * Validates the argument count, parses the optional option string, creates the scorer and
+     * builds the output struct inspector.
+     *
+     * @param argOIs inspectors for {@code x} and, optionally, the constant option string
+     * @return struct inspector with {@code changepoint_score} and, when a threshold is
+     *         configured, {@code is_changepoint}
+     * @throws UDFArgumentException if the number of arguments is not 1 or 2
+     */
+    @Override
+    public ObjectInspector initialize(@Nonnull ObjectInspector[] argOIs)
+            throws UDFArgumentException {
+        if (argOIs.length < 1 || argOIs.length > 2) {
+            throw new UDFArgumentException(
+                "_FUNC_(double|array<double> x [, const string options]) takes 1 or 2 arguments: "
+                        + Arrays.toString(argOIs));
+        }
+
+        this._params = new Parameters();
+        if (argOIs.length == 2) {
+            String options = HiveUtils.getConstString(argOIs[1]);
+            processOptions(options);
+        }
+
+        ObjectInspector argOI0 = argOIs[0];
+        PrimitiveObjectInspector xOI = HiveUtils.asDoubleCompatibleOI(argOI0);
+        this._sst = new SSTChangePoint(_params, xOI);
+
+        this._scores = new double[1];
+
+        final Object[] result;
+        final ArrayList<String> fieldNames = new ArrayList<String>();
+        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
+        fieldNames.add("changepoint_score");
+        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
+        if (_params.changepointThreshold != -1d) {
+            fieldNames.add("is_changepoint");
+            fieldOIs.add(PrimitiveObjectInspectorFactory.writableBooleanObjectInspector);
+            result = new Object[2];
+            this._isChangepoint = new BooleanWritable(false);
+            result[1] = _isChangepoint;
+        } else {
+            result = new Object[1];
+        }
+        this._changepointScore = new DoubleWritable(0.d);
+        result[0] = _changepointScore;
+        this._result = result;
+
+        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
+    }
+
+    /**
+     * Scores one sample.
+     *
+     * <p>A {@code null} sample is not fed to the scorer; the shared result row is returned
+     * unchanged, so it still carries the values written by the previous call.
+     *
+     * @param args {@code args[0]} is the sample
+     * @return the shared result row: score and, if configured, the decision
+     */
+    @Override
+    public Object[] evaluate(@Nonnull DeferredObject[] args) throws HiveException {
+        Object x = args[0].get();
+        if (x == null) {
+            return _result;
+        }
+
+        _sst.update(x, _scores);
+
+        double changepointScore = _scores[0];
+        _changepointScore.set(changepointScore);
+        if (_isChangepoint != null) {
+            // the threshold is inclusive
+            _isChangepoint.set(changepointScore >= _params.changepointThreshold);
+        }
+
+        return _result;
+    }
+
+    /** Drops the references to the reusable result objects. */
+    @Override
+    public void close() throws IOException {
+        this._result = null;
+        this._changepointScore = null;
+        this._isChangepoint = null;
+    }
+
+    @Override
+    public String getDisplayString(String[] children) {
+        return "sst(" + Arrays.toString(children) + ")";
+    }
+
+    /** Mutable holder of the SST options; the field initializers are the defaults. */
+    static final class Parameters {
+        /** Window length. */
+        int w = 30;
+        /** Number of past windows. */
+        int n = 30;
+        /** Number of current windows. */
+        int m = 30;
+        /** Offset of the current windows. */
+        int g = -30;
+        /** Number of singular vectors (principal components). */
+        int r = 3;
+        /** Number of Krylov subspace dimensions; parsed and validated, but not read by SSTChangePoint. */
+        int k = 5;
+        /** Inclusive decision threshold; {@code -1} means no decision column is emitted. */
+        double changepointThreshold = -1.d;
+
+        Parameters() {}
+    }
+
+    /** Contract of a scorer: consume one sample and write the score(s) to {@code outScores}. */
+    public interface SSTChangePointInterface {
+        void update(@Nonnull Object arg, @Nonnull double[] outScores) throws HiveException;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/3ebd771e/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java b/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
new file mode 100644
index 0000000..b41d474
--- /dev/null
+++ b/core/src/test/java/hivemall/anomaly/SSTChangePointTest.java
@@ -0,0 +1,111 @@
+/*
+ * Hivemall: Hive scalable Machine Learning Library
+ *
+ * Copyright (C) 2015 Makoto YUI
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package hivemall.anomaly;
+
+import hivemall.anomaly.SSTChangePointUDF.Parameters;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.zip.GZIPInputStream;
+
+import javax.annotation.Nonnull;
+
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Smoke tests for {@link SSTChangePoint} with default parameters: each test streams one value
+ * per line from a classpath resource and counts the scores that exceed a fixed threshold.
+ */
+public class SSTChangePointTest {
+    /** Set to {@code true} to print every sample and its score to stdout. */
+    private static final boolean DEBUG = false;
+
+    /** Expects between 1 and 4 scores above 0.95 on {@code cf1d.csv}. */
+    @Test
+    public void testSST() throws IOException, HiveException {
+        Parameters params = new Parameters();
+        PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+        SSTChangePoint sst = new SSTChangePoint(params, oi);
+        double[] outScores = new double[1];
+
+        int numChangepoints = 0;
+        BufferedReader reader = readFile("cf1d.csv");
+        try {
+            println("x change");
+            String line;
+            while ((line = reader.readLine()) != null) {
+                double x = Double.parseDouble(line);
+                sst.update(x, outScores);
+                printf("%f %f%n", x, outScores[0]);
+                if (outScores[0] > 0.95d) {
+                    numChangepoints++;
+                }
+            }
+        } finally {
+            // release the resource stream even when parsing or scoring fails
+            reader.close();
+        }
+        Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+            numChangepoints > 0);
+        Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+            numChangepoints < 5);
+    }
+
+    /** Expects between 1 and 4 scores above 0.005 on the gzip-compressed {@code twitter.csv.gz}. */
+    @Test
+    public void testTwitterData() throws IOException, HiveException {
+        Parameters params = new Parameters();
+        PrimitiveObjectInspector oi = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
+        SSTChangePoint sst = new SSTChangePoint(params, oi);
+        double[] outScores = new double[1];
+
+        int i = 1, numChangepoints = 0;
+        BufferedReader reader = readFile("twitter.csv.gz");
+        try {
+            println("# time x change");
+            String line;
+            while ((line = reader.readLine()) != null) {
+                double x = Double.parseDouble(line);
+                sst.update(x, outScores);
+                printf("%d %f %f%n", i, x, outScores[0]);
+                if (outScores[0] > 0.005d) {
+                    numChangepoints++;
+                }
+                i++;
+            }
+        } finally {
+            // release the resource stream even when parsing or scoring fails
+            reader.close();
+        }
+        Assert.assertTrue("#changepoints SHOULD be greater than 0: " + numChangepoints,
+            numChangepoints > 0);
+        Assert.assertTrue("#changepoints SHOULD be less than 5: " + numChangepoints,
+            numChangepoints < 5);
+    }
+
+    private static void println(String msg) {
+        if (DEBUG) {
+            System.out.println(msg);
+        }
+    }
+
+    private static void printf(String format, Object... args) {
+        if (DEBUG) {
+            System.out.printf(format, args);
+        }
+    }
+
+    /**
+     * Opens a test resource located next to this class; resources whose name ends with
+     * {@code .gz} are decompressed on the fly.
+     *
+     * @param fileName resource name relative to this class
+     * @return a reader the caller must close
+     * @throws IOException if the resource is missing or cannot be opened
+     */
+    @Nonnull
+    private static BufferedReader readFile(@Nonnull String fileName) throws IOException {
+        final InputStream raw = SSTChangePointTest.class.getResourceAsStream(fileName);
+        if (raw == null) {
+            // getResourceAsStream() signals a missing resource with null, not an exception
+            throw new IOException("Test resource not found on classpath: " + fileName);
+        }
+        InputStream is = raw;
+        if (fileName.endsWith(".gz")) {
+            try {
+                is = new GZIPInputStream(raw);
+            } catch (IOException e) {
+                // do not leak the underlying stream when the gzip header is invalid
+                raw.close();
+                throw e;
+            }
+        }
+        // Explicit charset instead of the platform default; the files are parsed with
+        // Double.parseDouble, so they are presumably plain ASCII numbers.
+        return new BufferedReader(new InputStreamReader(is, "UTF-8"));
+    }
+
+}