This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 7c4f3455958e660669c7cb42cb7f3884314c71fa Author: Matthias Boehm <[email protected]> AuthorDate: Mon Feb 6 19:03:39 2023 +0100 [SYSTEMDS-3496] New auc() builtin function for the area under ROC curves This patch introduces the new auc() builtin function that takes a response vector Y and probabilities P (e.g., from multiLogRegPredict) and computes the area under the Receiver-Operating-Characteristic curve. The current implementation naively computes the distinct probabilities and then evaluates the true and false positive rates for all these possible thresholds, with semantics equivalent to the R pROC package. Next steps include fixes for compiling unique operations (currently requires forced single node), missing unique spark operations, and a more efficient vectorized auc() implementation via cumsum (and unique extensions to obtain the last indexes of unique values). --- scripts/builtin/auc.dml | 72 +++++++++++ .../java/org/apache/sysds/common/Builtins.java | 1 + .../functions/builtin/part1/BuiltinAucTest.java | 133 +++++++++++++++++++++ src/test/scripts/functions/builtin/auc.dml | 25 ++++ 4 files changed, 231 insertions(+) diff --git a/scripts/builtin/auc.dml b/scripts/builtin/auc.dml new file mode 100644 index 0000000000..1084b62eba --- /dev/null +++ b/scripts/builtin/auc.dml @@ -0,0 +1,72 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# This builtin function computes the area under the ROC curve (AUC) +# for binary classifiers. +# +# INPUT: +# ------------------------------------------------------------------------------ +# Y Binary response vector (shape: n x 1), in -1/+1 or 0/1 encoding +# P Prediction scores (predictor such as estimated probabilities) +# for true class (shape: n x 1), assumed in [0,1] +# ------------------------------------------------------------------------------ +# +# OUTPUT: +# ------------------------------------------------------------------------------ +# auc Area under the ROC curve (AUC) +# ------------------------------------------------------------------------------ + +m_auc = function(Matrix[Double] Y, Matrix[Double] P) + return(Double auc) +{ + minv = min(Y) + maxv = max(Y) + + # check input parameter assertions + if(minv == maxv) + stop("AUC: stopping because only one class label existing in Y") + if(sum(Y==minv) + sum(Y==maxv) < nrow(Y)) + stop("AUC: stopping because more than two class labels existing in Y") + + # convert -1/1 to 0/1 if necessary + if( minv < 0 ) + Y = (Y+1) != 0; + pos = sum(Y); + neg = nrow(Y) - pos; + + # compute ROC curve for distinct threshold scores + # (cut-offs > and <= chosen to match R-pROC-package behavior) + # TODO vectorize via ordering + cumsum (but indexes of unique missing) + dP = order(target=unique(P)); # distinct P thresholds, increasing + nd = nrow(dP) + tp = matrix(0, nd, 1); + fp = matrix(0, nd, 1); + parfor(i in 1:nd) { + tp[i] = sum(P>dP[i] 
& Y) + fp[i] = sum(P<=dP[i] & !Y) + } + tpr = tp / pos; # true positive rate, increasing + fpr = fp / neg; # false positive rate, increasing + + # compute AUC via Trapezoidal rule + auc = as.scalar(tpr[1] * fpr[1]) + + sum((fpr[2:nd]-fpr[1:(nd-1)]) * (tpr[2:nd]+tpr[1:(nd-1)])/2); +} diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 24215ffdab..f7cbb972df 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -51,6 +51,7 @@ public enum Builtins { ARIMA("arima", true), ASIN("asin", false), ATAN("atan", false), + AUC("auc", true), AUTOENCODER2LAYER("autoencoder_2layer", true), AVG_POOL("avg_pool", false), AVG_POOL_BACKWARD("avg_pool_backward", false), diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java new file mode 100644 index 0000000000..49502e8371 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAucTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.builtin.part1; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; + +public class BuiltinAucTest extends AutomatedTestBase +{ + private final static String TEST_NAME = "auc"; + private final static String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinAucTest.class.getSimpleName() + "/"; + + private double eps = 0.01; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"})); + } + + //FIXME missing spark instruction unique + + @Test + public void testPerfectSeparationOrdered() { + runAucTest(1.0, new double[]{0,0,0,1,1,1}, + new double[]{0.1,0.2,0.3,0.4,0.55,0.56}); + } + + @Test + public void testPerfectSeparationUnordered() { + runAucTest(1.0, new double[]{0,1,0,1,0,1}, + new double[]{0.1,0.5,0.2,0.55,0.3,0.56}); + } + + @Test + public void testPerfectSeparationUnorderedDups() { + runAucTest(1.0, new double[]{0,1,0,1,0,1,0,1,0,1,0,1}, + new double[]{0.1,0.5,0.2,0.55,0.3,0.56,0.1,0.5,0.2,0.55,0.3,0.56}); + } + + //selected cases, double checked with R pROC (but not explicitly compared to avoid dependency) + + @Test + public void testMisc1() { + runAucTest(0.8899, new double[]{0,0,1,0,1,1}, + new double[]{0.1,0.2,0.3,0.4,0.5,0.55}); + } + + @Test + public void testMisc2() { + runAucTest(0.8899, new double[]{-1,-1,1,-1,1,1}, + new double[]{0.1,0.2,0.3,0.4,0.5,0.55}); + } + + @Test + public void testMisc3() { + runAucTest(0.75, new double[]{0,0,1,0,1,1,0,1}, + new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7}); + } + + @Test + public void testMisc4() { + runAucTest(0.6, new double[]{0,0,1,0,1,1,0,1,0}, + new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7,0.9}); + } + + @Test + public 
void testMisc5() { + runAucTest(0.6, new double[]{0,0,0,1,0,1,1,0,1}, + new double[]{0.9,0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7}); + } + + @Test + public void testMisc6() { + runAucTest(0.5, new double[]{0,0,1,0,1,1,0,1,0,0}, + new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7,0.9,0.9}); + } + + @Test + public void testMisc7() { + runAucTest(0.4286, new double[]{0,0,1,0,1,1,0,1,0,0,0}, + new double[]{0.1,0.2,0.2,0.21,0.7,0.7,0.7,0.7,0.9,0.9,0.99}); + } + + private void runAucTest(double auc, double[] Y, double[] P) + { + ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE); + + try + { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{"-args", input("Yt"), input("Pt"), output("C") }; + + //generate actual dataset + writeInputMatrixWithMTD("Yt", new double[][]{Y}, false); + writeInputMatrixWithMTD("Pt", new double[][]{P}, false); + + //execute test + runTest(true, false, null, -1); + + //compare matrices + double val = readDMLMatrixFromOutputDir("C").get(new CellIndex(1,1)); + Assert.assertEquals("Incorrect values: ", auc, val, eps); + } + finally { + rtplatform = platformOld; + } + } +} diff --git a/src/test/scripts/functions/builtin/auc.dml b/src/test/scripts/functions/builtin/auc.dml new file mode 100644 index 0000000000..065546f45d --- /dev/null +++ b/src/test/scripts/functions/builtin/auc.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +Y = read($1); +P = read($2); +C = as.matrix(auc(t(Y), t(P))); +write(C, $3)
