Baunsgaard commented on a change in pull request #1189: URL: https://github.com/apache/systemds/pull/1189#discussion_r584967658
########## File path: src/test/java/org/apache/sysds/test/functions/builtin/BuiltinCoxTest.java ########## @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin; + +import org.apache.sysds.common.Types; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.junit.Test; + +public class BuiltinCoxTest extends AutomatedTestBase +{ + private final static String TEST_NAME = "cox"; + private final static String TEST_DIR = "functions/builtin/"; + private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinCoxTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"B"})); + } + + @Test + public void testFunction() { + runCoxTest(50, 2.0, 1.5, 0.8, 100, 0.05, 1.0,0.000001, 100, 0); + } + + public void runCoxTest(int numRecords, double scaleWeibull, double shapeWeibull, double prob, + int numFeatures, double sparsity, double alpha, double tol, int moi, int mii) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); 
+ String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + programArgs = new String[]{ + "-nvargs", "M=" + output("M"), "S=" + output("S"), "T=" + output("T"), "COV=" + output("COV"), + "RT=" + output("RT"), "XO=" + output("XO"), "n=" + numRecords, "l=" + scaleWeibull, + "v=" + shapeWeibull, "p=" + prob, "m=" + numFeatures, "sp=" + sparsity, + "alpha=" + alpha, "tol=" + tol, "moi=" + moi, "mii=" + mii}; + + runTest(true, false, null, -1); Review comment: This is not sufficient testing, since it only runs the script but does not verify that the output is correct. I would suggest making the inputs as simple as possible, so that you know what output should be produced, and then verifying that output. ########## File path: src/test/java/org/apache/sysds/test/functions/builtin/BuiltinKmTest.java ########## @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.builtin; + +import org.apache.sysds.common.Types; +import org.apache.sysds.lops.LopProperties; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class BuiltinKmTest extends AutomatedTestBase +{ + private final static String TEST_NAME = "km"; + private final static String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinKmTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"C"})); + } + + @Test + public void testKmDefaultConfiguration() { + runKmTest(50, 2.0, 1.5, 0.8, 2, + 1, 10, 0.05,"greenwood", "log", "none"); + } + @Test + public void testKmErrTypePeto() { + runKmTest(50, 2.0, 1.5, 0.8, 2, + 1, 10, 0.05,"peto", "log", "none"); + } + @Test + public void testKmConfTypePlain() { + runKmTest(50, 2.0, 1.5, 0.8, 2, + 1, 10, 0.05,"greenwood", "plain", "none"); + } + @Test + public void testKmConfTypeLogLog() { + runKmTest(50, 2.0, 1.5, 0.8, 2, + 1, 10, 0.05,"greenwood", "log-log", "none"); + } + + private void runKmTest(int numRecords, double scaleWeibull, double shapeWeibull, double prob, + int numCatFeaturesGroup, int numCatFeaturesStrat, int maxNumLevels, double alpha, String err_type, + String conf_type, String test_type) { + Types.ExecMode platformOld = setExecMode(LopProperties.ExecType.SPARK); + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + programArgs = new String[]{ + "-nvargs", "O=" + output("O"), "M=" + output("M"), "T=" + output("T"), + "T_GROUPS_OE=" + output("T_GROUPS_OE"), "n=" + numRecords, "l=" + scaleWeibull, + "v=" + shapeWeibull, "p=" + prob, "g=" + numCatFeaturesGroup, "s=" + 
numCatFeaturesStrat, + "f=" + maxNumLevels, "alpha=" + alpha, "err_type=" + err_type, + "conf_type=" + conf_type, "test_type=" + test_type}; + + runTest(true, false, null, -1); Review comment: same testing problem here. ########## File path: src/test/scripts/functions/builtin/cox.dml ########## @@ -0,0 +1,69 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +num_records = $n; +lambda = $l; +p_event = $p; +# parameters related to the cox model +num_features = $m; +sparsity = $sp; +p_censor = 1 - p_event; # prob. that record is censored + +v = $v; +# generate feature matrix +X_t = rand (rows = num_records, cols = num_features, min = 1, max = 5, pdf = "uniform", sparsity = sparsity); Review comment: Give it a seed to make it deterministic, or generate a known matrix inside your tests and then read it from disk. Some of the other builtin function tests could serve as inspiration. 
########## File path: scripts/builtin/km.dml ########## @@ -0,0 +1,656 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# +# Builtin function that implements the analysis of survival data with KAPLAN-MEIER estimates +# +# INPUT PARAMETERS: +# --------------------------------------------------------------------------------------------- +# NAME TYPE DEFAULT MEANING +# --------------------------------------------------------------------------------------------- +# X String --- Location to read the input matrix X containing the survival data: +# timestamps, whether event occurred (1) or data is censored (0), and a number of factors (categorical features) +# for grouping and/or stratifying Review comment: Our coding conventions use spaces instead of tabs in DML scripts. This is why it looks strange on GitHub: the indentation is not consistent. Try to replace all tabs with spaces. In the documentation up here, make the continuation lines align with the line above. 
########## File path: scripts/builtin/km.dml ########## @@ -0,0 +1,656 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# +# Builtin function that implements the analysis of survival data with KAPLAN-MEIER estimates +# +# INPUT PARAMETERS: +# --------------------------------------------------------------------------------------------- +# NAME TYPE DEFAULT MEANING +# --------------------------------------------------------------------------------------------- +# X String --- Location to read the input matrix X containing the survival data: +# timestamps, whether event occurred (1) or data is censored (0), and a number of factors (categorical features) +# for grouping and/or stratifying +# TE String --- Column indices of X which contain timestamps (first entry) and event information (second entry) +# GI String --- Column indices of X corresponding to the factors to be used for grouping +# SI String --- Column indices of X corresponding to the factors to be used for stratifying +# O String --- Location to write the matrix containing the results of the Kaplan-Meier analysis; see below for the 
description +# M String --- Location to write Matrix M containing the following statistic: total number of events, median and its confidence intervals; +# if survival data for multiple groups and strata are provided each row of M contains the above statistics per group and stratum +# T String " " If survival data from multiple groups available and ttype=log-rank or wilcoxon, +# location to write the matrix containing result of the (stratified) test for comparing multiple groups +# alpha Double 0.05 Parameter to compute 100*(1-alpha)% confidence intervals for the survivor function and its median +# err_type String "greenwood" Parameter to specify the error type according to "greenwood" (the default) or "peto" +# conf_type String "log" Parameter to modify the confidence interval; "plain" keeps the lower and upper bound of +# the confidence interval unmodified, "log" (the default) corresponds to logistic transformation and +# "log-log" corresponds to the complementary log-log transformation +# test_type String "none" If survival data for multiple groups is available specifies which test to perform for comparing +# survival data across multiple groups: "none" (the default) "log-rank" or "wilcoxon" test +# fmtO String "text" The output format of results of the Kaplan-Meier analysis, such as "text" or "csv" +# --------------------------------------------------------------------------------------------- +# OUTPUT: +# 1- Matrix KM whose dimension depends on the number of groups (denoted by g) and strata (denoted by s) in the data: +# each collection of 7 consecutive columns in KM corresponds to a unique combination of groups and strata in the data with the following schema +# 1. col: timestamp +# 2. col: no. at risk +# 3. col: no. of events +# 4. col: Kaplan-Meier estimate of survivor function surv +# 5. col: standard error of surv +# 6. col: lower 100*(1-alpha)% confidence interval for surv +# 7. 
col: upper 100*(1-alpha)% confidence interval for surv +# 2- Matrix M whose dimension depends on the number of groups (g) and strata (s) in the data (k denotes the number of factors used for grouping +# ,i.e., ncol(GI) and l denotes the number of factors used for stratifying, i.e., ncol(SI)) +# M[,1:k]: unique combination of values in the k factors used for grouping +# M[,(k+1):(k+l)]: unique combination of values in the l factors used for stratifying +# M[,k+l+1]: total number of records +# M[,k+l+2]: total number of events +# M[,k+l+3]: median of surv +# M[,k+l+4]: lower 100*(1-alpha)% confidence interval of the median of surv +# M[,k+l+5]: upper 100*(1-alpha)% confidence interval of the median of surv +# If the number of groups and strata is equal to 1, M will have 4 columns with +# M[,1]: total number of events +# M[,2]: median of surv +# M[,3]: lower 100*(1-alpha)% confidence interval of the median of surv +# M[,4]: upper 100*(1-alpha)% confidence interval of the median of surv +# 3- If survival data from multiple groups available and ttype=log-rank or wilcoxon, a 1 x 4 matrix T and an g x 5 matrix T_GROUPS_OE with +# T_GROUPS_OE[,1] = no. of events +# T_GROUPS_OE[,2] = observed value (O) +# T_GROUPS_OE[,3] = expected value (E) +# T_GROUPS_OE[,4] = (O-E)^2/E +# T_GROUPS_OE[,5] = (O-E)^2/V +# T[1,1] = no. of groups +# T[1,2] = degree of freedom for Chi-squared distributed test statistic +# T[1,3] = test statistic +# T[1,4] = P-value +# ------------------------------------------------------------------------------------------- + +m_km = function(Matrix[Double] X, Matrix[Double] TE, Matrix[Double] GI, Matrix[Double] SI, + Double alpha = 0.05, String err_type = "greenwood", + String conf_type = "log", String test_type = "none") + return (Matrix[Double] O, Matrix[Double] M, + Matrix[Double] T, Matrix[Double] T_GROUPS_OE) { + + +if (ncol(GI) != 0 & nrow(GI) != 0) { Review comment: indent the content of the function. 
########## File path: src/test/scripts/functions/builtin/cox.dml ########## @@ -0,0 +1,69 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +num_records = $n; +lambda = $l; +p_event = $p; +# parameters related to the cox model +num_features = $m; +sparsity = $sp; +p_censor = 1 - p_event; # prob. 
that record is censored + +v = $v; +# generate feature matrix +X_t = rand (rows = num_records, cols = num_features, min = 1, max = 5, pdf = "uniform", sparsity = sparsity); +# generate coefficients +B = rand (rows = num_features, cols = 1, min = -1.0, max = 1.0, pdf = "uniform", sparsity = 1.0); + +# generate timestamps +U = rand (rows = num_records, cols = 1, min = 0.000000001, max = 1); +T = (-log (U) / (lambda * exp (X_t %*% B)) ) ^ (1/v); + +Y = matrix (0, rows = num_records, cols = 2); +event = floor (rand (rows = num_records, cols = 1, min = (1 - p_censor), max = (1 + p_event))); +n_time = sum (event); +Y[,2] = event; + +# binning of event times +min_T = min (T); +max_T = max (T); +# T = T - min_T; +len = max_T - min_T; +num_bins = len / n_time; +T = ceil (T / num_bins); + +# print ("min(T) " + min(T) + " max(T) " + max(T)); +Y[,1] = T; + +X = cbind (Y, X_t); + +TE = matrix ("1 2", rows = 2, cols = 1); +F = seq (1, num_features); +R = matrix (0, rows = 1, cols = 1); + +[M, S, T, COV, RT, XO] = cox(X, TE, F, R, $alpha, $tol, $moi, $mii); + +write(M, $M); +write(S, $S); +write(T, $T); +write(COV, $COV); +write(RT, $RT); +write(XO, $XO); Review comment: good that you write out the results, now use them to verify. ########## File path: scripts/builtin/cox.dml ########## @@ -0,0 +1,478 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# +# THIS SCRIPT FITS A COX PROPORTIONAL HAZARD REGRESSION MODEL. +# The Breslow method is used for handling ties and the regression parameters +# are computed using trust region newton method with conjugate gradient +# +# INPUT PARAMETERS: +# --------------------------------------------------------------------------------------------- +# NAME TYPE DEFAULT MEANING +# --------------------------------------------------------------------------------------------- +# X String --- Location to read the input matrix X containing the survival data containing the following information +# - 1: timestamps +# - 2: whether an event occurred (1) or data is censored (0) +# - 3: feature vectors +# TE String --- Column indices of X as a column vector which contain timestamp (first row) and event information (second row) +# F String " " Column indices of X as a column vector which are to be used for fitting the Cox model +# R String " " If factors (categorical variables) are available in the input matrix X, location to read matrix R containing +# the start and end indices of the factors in X +# - R[,1]: start indices +# - R[,2]: end indices +# Alternatively, user can specify the indices of the baseline level of each factor which needs to be removed from X; +# in this case the start and end indices corresponding to the baseline level need to be the same; +# if R is not provided by default all variables are considered to be continuous +# M String --- Location to store the results of Cox regression analysis 
including estimated regression parameters of the fitted +# Cox model (the betas), their standard errors, confidence intervals, and P-values +# S String " " Location to store a summary of some statistics of the fitted cox proportional hazard model including +# no. of records, no. of events, log-likelihood, AIC, Rsquare (Cox & Snell), and max possible Rsquare; +# by default is standard output +# T String " " Location to store the results of Likelihood ratio test, Wald test, and Score (log-rank) test of the fitted model; +# by default is standard output +# COV String --- Location to store the variance-covariance matrix of the betas +# RT String --- Location to store matrix RT containing the order-preserving recoded timestamps from X +# XO String --- Location to store sorted input matrix by the timestamps +# MF String --- Location to store column indices of X excluding the baseline factors if available +# alpha Double 0.05 Parameter to compute a 100*(1-alpha)% confidence interval for the betas +# tol Double 0.000001 Tolerance ("epsilon") +# moi Int 100 Max. number of outer (Newton) iterations +# mii Int 0 Max. number of inner (conjugate gradient) iterations, 0 = no max +# fmt String "text" Matrix output format, usually "text" or "csv" (for matrices only) +# --------------------------------------------------------------------------------------------- +# OUTPUT: +# 1- A D x 7 matrix M, where D denotes the number of covariates, with the following schema: +# M[,1]: betas +# M[,2]: exp(betas) +# M[,3]: standard error of betas +# M[,4]: Z +# M[,5]: P-value +# M[,6]: lower 100*(1-alpha)% confidence interval of betas +# M[,7]: upper 100*(1-alpha)% confidence interval of betas +# +# Two matrices containing a summary of some statistics of the fitted model: +# 1- File S with the following format +# - row 1: no. of observations +# - row 2: no. 
of events +# - row 3: log-likelihood +# - row 4: AIC +# - row 5: Rsquare (Cox & Snell) +# - row 6: max possible Rsquare +# 2- File T with the following format +# - row 1: Likelihood ratio test statistic, degree of freedom, P-value +# - row 2: Wald test statistic, degree of freedom, P-value +# - row 3: Score (log-rank) test statistic, degree of freedom, P-value +# +# Additionally, the following matrices are stored (needed for prediction) +# 1- A column matrix RT that contains the order-preserving recoded timestamps from X +# 2- Matrix XO which is matrix X with sorted timestamps +# 3- Variance-covariance matrix of the betas COV +# 4- A column matrix MF that contains the column indices of X with the baseline factors removed (if available) +# ------------------------------------------------------------------------------------------- +m_cox = function(Matrix[Double] X, Matrix[Double] TE, Matrix[Double] F, Matrix[Double] R, + Double alpha = 0.05, Double tol = 0.000001, Int moi = 100, Int mii = 0) + return (Matrix[Double] M, Matrix[Double] S, Matrix[Double] T, + Matrix[Double] COV, Matrix[Double] RT, Matrix[Double] XO) { + +X_orig = X; Review comment: here also indent the content of the function ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
