This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new b651db516d [SYSTEMDS-3669] Builtin for computation of shapley values
b651db516d is described below
commit b651db516da31222018396bb996b3d825766c7da
Author: louislepage <[email protected]>
AuthorDate: Sun Oct 27 17:22:35 2024 +0100
[SYSTEMDS-3669] Builtin for computation of shapley values
Closes #1946.
---
scripts/builtin/shapExplainer.dml | 732 +++++++++++++++++++++
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../builtin/part2/BuiltinShapExplainerTest.java | 156 +++++
.../functions/builtin/shapExplainerComponent.dml | 57 ++
.../functions/builtin/shapExplainerUnit.dml | 104 +++
5 files changed, 1050 insertions(+)
diff --git a/scripts/builtin/shapExplainer.dml
b/scripts/builtin/shapExplainer.dml
new file mode 100644
index 0000000000..626dc7da4c
--- /dev/null
+++ b/scripts/builtin/shapExplainer.dml
@@ -0,0 +1,732 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Computes Shapley values for multiple instances in parallel using antithetic permutation sampling.
+# The resulting matrix row_phis holds the Shapley values for each feature in the column given by the
+# index of the feature in the sample.
+#
+# This method first creates two large matrices for masks and masked background data for all
+# permutations and then runs in parallel on all instances in x_instances.
+# While the prepared matrices can become very large (2 * #features * #permutations * #n_samples * #features),
+# the preparation of a row for the model call breaks down to a single element-wise multiplication of
+# this mask with the row and an addition to the masked background data, since masks can be reused
+# for each instance.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# model_function  The function of the model to be evaluated as a String. This function has to
+#                 take a matrix of samples and return a vector of predictions.
+#                 It might be useful to wrap the model into a function that takes and returns
+#                 the desired shapes and use this wrapper here.
+# model_args      Arguments in order for the model, if desired. This will be prepended by the
+#                 created instances-matrix.
+# x_instances     Multiple instances as rows for which to compute the Shapley values.
+# X_bg            The background dataset from which to pull the random samples to perform
+#                 Monte Carlo integration.
+# n_permutations  The number of permutations. Defaults to 10. Theoretically, 1 should already be
+#                 enough for models with up to second-order interaction effects.
+# n_samples       Number of samples from X_bg used for marginalization. Defaults to 100.
+# remove_non_var  EXPERIMENTAL: If set, for every instance the variance of each feature is checked
+#                 against this feature in the background data. If it does not change, we do not
+#                 run any model calls for it.
+# partitions      Optional matrix with the first feature of each partition in the first row and
+#                 the last feature of each partition in the second row. Each partition is treated
+#                 as a single feature (e.g. for one-hot-encoded features). Cannot be combined with
+#                 remove_non_var.
+# seed            A seed, in case the sampling has to be deterministic.
+# verbose         A boolean to enable logging of each step of the function.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# row_phis        Matrix holding the Shapley values along the cols, one row per instance.
+# expected        Double holding the expected model prediction on the background samples
+#                 (the base value), averaged over all instances.
+# -----------------------------------------------------------------------------
+s_shapExplainer = function(String model_function, list[unknown] model_args, Matrix[Double] x_instances,
+  Matrix[Double] X_bg, Integer n_permutations = 10, Integer n_samples = 100, Integer remove_non_var=0,
+  Matrix[Double] partitions=as.matrix(-1), Integer seed = -1, Integer verbose = 0)
+  return (Matrix[Double] row_phis, Double expected)
+{
+  u_printShapMessage("Parallel Permutation Explainer for "+nrow(x_instances)+" rows.", verbose)
+  u_printShapMessage("Number of Features: "+ncol(x_instances), verbose )
+  total_preds=ncol(x_instances)*2*n_permutations*n_samples*nrow(x_instances)
+  u_printShapMessage("Number of predictions: "+toString(total_preds)+" in "+nrow(x_instances)+
+    " parallel cals.", verbose )
+
+  #start with all features
+  features=u_range(1, ncol(x_instances))
+
+  #handle partitions: each partition stays in the permutations through its first feature only
+  if(sum(partitions) != -1){
+    if(remove_non_var != 0){
+      stop("shapley_permutations_by_row:ERROR: Can't use n_non_varying_inds and partitions at the same time.")
+    }
+    features=removePartitionsFromFeatures(features, partitions)
+    reduced_total_preds=ncol(features)*2*n_permutations*n_samples*nrow(x_instances)
+    u_printShapMessage("Using Partitions reduces number of features to "+ncol(features)+".", verbose )
+    u_printShapMessage("Total number of predictions reduced by "+(total_preds-reduced_total_preds)/total_preds+" to "+reduced_total_preds+".", verbose )
+  }
+
+  #lengths and offsets
+  total_features = ncol(x_instances)
+  perm_length = ncol(features)
+  #number of mask rows per permutation: (forward + backward pass) replicated n_samples times
+  full_mask_offset = perm_length * 2 * n_samples
+  n_partition_features = total_features - perm_length
+
+  #sample from X_bg
+  u_printShapMessage("Sampling from X_bg", verbose )
+  # could use new samples for each permutation by sampling n_samples*n_permutations
+  X_bg_samples = u_sample_with_potential_replace(X_bg=X_bg, samples=n_samples, seed=seed )
+  row_phis = matrix(0, rows=nrow(x_instances), cols=total_features)
+  expected_m = matrix(0, rows=nrow(x_instances), cols=1)
+
+  #prepare masks for all permutations, since they stay the same for every row
+  u_printShapMessage("Preparing reusable intermediate masks.", verbose )
+  permutations = matrix(0, rows=n_permutations, cols=perm_length)
+  masks_for_permutations = matrix(0, rows=perm_length*2*n_permutations*n_samples, cols=total_features)
+  #draw the random sort keys of all permutations up front from the given seed, so that a fixed
+  #seed also fixes the permutations (seed=-1 keeps them random)
+  perm_keys = rand(rows=n_permutations, cols=perm_length, min=0, max=1, pdf="uniform", seed=seed)
+
+  parfor (i in 1:n_permutations, check=0){
+    #shuffle features by ordering them along their random keys to get a permutation
+    ordered_features = order(target=cbind(t(features), t(perm_keys[i])), by=2)
+    permutations[i] = t(ordered_features[,1])
+    perm_mask = prepare_mask_for_permutation(permutation=permutations[i], partitions=partitions)
+
+    offset_masks = (i-1) * full_mask_offset + 1
+    masks_for_permutations[offset_masks:offset_masks+full_mask_offset-1]=prepare_full_mask(perm_mask, n_samples)
+  }
+
+  #replicate background and mask it, since it also can stay the same for every row
+  # could use new samples for each permutation by sampling n_samples*n_permutations and telling this function about it
+  masked_bg_for_permutations = prepare_masked_X_bg(masks_for_permutations, X_bg_samples, 0)
+  u_printShapMessage("Computing phis in parallel.", verbose )
+
+  #enable spark execution for parfor if desired
+  #TODO allow spark mode via parameter?
+  #parfor (i in 1:nrow(x_instances), opt=CONSTRAINED, mode=REMOTE_SPARK){
+
+  parfor (i in 1:nrow(x_instances)){
+    if(remove_non_var != 0){
+      # try to remove inds that do not vary from the background
+      non_var_inds = get_non_varying_inds(x_instances[i], X_bg_samples)
+      # only remove if indices were found (when every feature varies, the result holds no positive
+      # index, presumably a single 0 from removeEmpty) and more than 2 features remain, since less
+      # than two breaks the removal procedure
+      if (sum(non_var_inds) > 0 & ncol(x_instances) > length(non_var_inds)+2){
+        #remove samples and masks for non varying features
+        [i_masks_for_permutations, i_masked_bg_for_permutations] = remove_inds(masks_for_permutations, masked_bg_for_permutations, permutations, non_var_inds, n_samples)
+      }else{
+        # nothing to remove, or we would remove all but two features, which breaks the removal algorithm
+        non_var_inds = as.matrix(-1)
+        i_masks_for_permutations = masks_for_permutations
+        i_masked_bg_for_permutations = masked_bg_for_permutations
+      }
+    } else {
+      non_var_inds = as.matrix(-1)
+      i_masks_for_permutations = masks_for_permutations
+      i_masked_bg_for_permutations = masked_bg_for_permutations
+    }
+
+    #apply masks and bg data for all permutations at once
+    X_test = apply_full_mask(x_instances[i], i_masks_for_permutations, i_masked_bg_for_permutations)
+
+    #generate args for call to model
+    X_arg = append(list(X=X_test), model_args)
+
+    #call model
+    P = eval(model_function, X_arg)
+
+    #compute means, dividing n_rows by n_samples
+    P = compute_means_from_predictions(P=P, n_samples=n_samples)
+
+    #compute phis
+    [phis, e] = compute_phis_from_prediction_means(P=P, permutations=permutations, non_var_inds=non_var_inds, n_partition_features=n_partition_features)
+    expected_m[i] = e
+
+    #compute phis for this row from all permutations
+    row_phis[i] = t(phis)
+  }
+  #compute expected of model from all rows
+  expected = mean(expected_m)
+}
+
+# Finds the features of a single instance that do not vary from the background data.
+# Follows the tolerance idea of numpy.isclose(): a feature counts as non-varying when its largest
+# absolute deviation from the background stays within atol + rtol * (largest absolute background value).
+# More advanced tolerances (e.g. based on the std-dev of the background) could be used in the future.
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# x                 One single instance.
+# X_bg              Background dataset.
+# -----------------------------------------------------------------------------
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# non_varying_inds  A row-vector with all the indices that do not vary from the background dataset.
+# -----------------------------------------------------------------------------
+get_non_varying_inds = function(Matrix[Double] x, Matrix[Double] X_bg)
+return (Matrix[Double] non_varying_inds){
+  #tolerances taken from numpy.isclose, adapted to the scale of shap's MSE
+  abs_tol = 1e-05
+  rel_tol = 1e-04
+
+  #largest deviation of each feature of x from the background, and the tolerance it is compared to
+  max_deviation = colMaxs(abs(X_bg - x))
+  tolerance = abs_tol + rel_tol * colMaxs(abs(X_bg))
+
+  #turn the boolean indicator into 1-based column indices and drop the varying (zero) entries
+  feature_ids = t(seq(1, ncol(x)))
+  non_varying_inds = removeEmpty(target=feature_ids * (max_deviation <= tolerance), margin="cols")
+}
+
+# Prepares a boolean mask for removing features according to a permutation.
+# The result holds 2 * ncol(permutation) rows: row 1 is the inverse of the full mask, rows
+# 2..ncol(permutation)+1 switch on the features of the permutation one by one (forward pass), and
+# the remaining rows are the inverted forward masks in reverse order (backward pass).
+# The resulting matrix needs to be inflated to a sample set by using prepare_full_mask() before
+# calling the model.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# permutation         A single permutation of varying features.
+#                     If using partitions, remove them beforehand by using removePartitionsFromFeatures().
+# n_non_varying_inds  The number of features that are not part of the permutation (e.g. because they
+#                     do not vary in the background data). Must be 0 when partitions are used.
+# partitions          Matrix with the first element of each partition in the first row and the last
+#                     element of each partition in the second row.
+#                     Used to treat partitions as one feature when creating masks. Useful for
+#                     one-hot-encoded features.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# masks               Boolean masks for the forward and backward pass.
+# -----------------------------------------------------------------------------
+prepare_mask_for_permutation = function(Matrix[Double] permutation, Integer n_non_varying_inds=0,
+  Matrix[Double] partitions=as.matrix(-1))
+return (Matrix[Double] masks){
+  if(sum(partitions)!=-1){
+    #can't use n_non_varying_inds and partitions at the same time
+    if(n_non_varying_inds > 0){
+      stop("shap-explainer::prepare_mask_for_permutation:ERROR: Can't use n_non_varying_inds and partitions at the same time.")
+    }
+    #number of features not in permutation is diff between start and end of partitions, since first feature remains in permutation
+    skip_inds = partitions[2,] - partitions[1,]
+
+    #skip these inds by treating them as non varying
+    n_non_varying_inds = sum(skip_inds)
+  }
+
+  #total number of features
+  perm_len = ncol(permutation)+n_non_varying_inds
+  if(n_non_varying_inds > 0){
+    #prep full constructor with placeholders (value perm_len+1 marks the slots not in the permutation)
+    mask_constructor = matrix(perm_len+1, rows=1, cols = perm_len)
+    mask_constructor[1,1:ncol(permutation)] = permutation
+  }else{
+    mask_constructor=permutation
+  }
+
+  perm_cols = ncol(mask_constructor)
+
+  # row k of the mask gets a 1 at each of the first k entries of the permutation:
+  # the (row, col) index pairs are taken from lower triangles and counted with a ctable
+
+  # create row indicator vector for ctable (k repeated k times)
+  perm_mask_rows = seq(1,perm_cols)
+  #TODO: col-vector and matrix mult?
+  perm_mask_rows = matrix(1, rows=perm_cols, cols=perm_cols) * perm_mask_rows
+  perm_mask_rows = lower.tri(target=perm_mask_rows, diag=TRUE, values=TRUE)
+  perm_mask_rows = removeEmpty(target=matrix(perm_mask_rows, rows=1, cols=length(perm_mask_rows)), margin="cols")
+
+  # create column indicator for ctable (the first k permutation entries for row k)
+  #TODO: col-vector and matrix mult?
+  perm_mask_cols = matrix(1, rows=perm_cols, cols=perm_cols) * mask_constructor
+  perm_mask_cols = lower.tri(target=perm_mask_cols, diag=TRUE, values=TRUE)
+  perm_mask_cols = removeEmpty(target = matrix(perm_mask_cols, cols=length(perm_mask_cols), rows=1), margin="cols")
+  #ctable
+  masks = table(perm_mask_rows, perm_mask_cols, perm_len, perm_len)
+  if(n_non_varying_inds > 0){
+    #truncate non varying rows
+    masks = masks[1:ncol(permutation)]
+
+    #replicate mask from first feature of each partition to entire partitions
+    if(sum(partitions)!=-1){
+      for ( i in 1:ncol(partitions) ){
+        p_start = as.scalar(partitions[1,i])
+        p_end = as.scalar(partitions[2,i])
+        proxy = masks[,p_start] %*% matrix(1, rows=1, cols=p_end-p_start)
+        masks[,p_start+1:p_end] = proxy
+      }
+    }
+  }
+
+  # add inverted mask and revert order for desired shape for forward and backward pass
+  masks = rbind(!masks[nrow(masks)],masks, rev(!masks[1:nrow(masks)-1]))
+}
+
+# Inflates a mask for marginalization: every row of the mask is repeated n_samples times in a row,
+# so that each mask row can later be paired with each of the n_samples background samples.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# mask         Boolean mask with 1 where the value is taken from x and 0 where it is
+#              integrated over the background data.
+# n_samples    Number of samples for which to replicate each row.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# x_mask_full  The row-wise replicated mask.
+# -----------------------------------------------------------------------------
+prepare_full_mask = function(Matrix[Double] mask, Integer n_samples)
+  return (Matrix[Double] x_mask_full)
+{
+  x_mask_full = u_repeatRows(M=mask, n_times=n_samples)
+}
+
+# Prepares the masked background by replicating the samples to the length of the full mask and
+# zeroing every cell that will be taken from the instance instead.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# x_mask_full         Boolean mask replicated row-wise.
+# X_bg_samples        Samples from background. Either the same n samples for all permutations or
+#                     n*p samples, so each permutation has its own samples.
+# n_perms_in_samples  Number of sample sets, used to identify the blocks of X_bg_samples that
+#                     need to be replicated (<= 1 means one shared sample set).
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# masked_X_bg         Replicated background samples with the masked-in cells set to zero.
+# -----------------------------------------------------------------------------
+prepare_masked_X_bg = function(Matrix[Double] x_mask_full, Matrix[Double] X_bg_samples, Integer n_perms_in_samples)
+return (Matrix[Double] masked_X_bg){
+  if (n_perms_in_samples > 1){
+    #X_bg_samples holds an independent block of samples per permutation (n_samples*n_perms rows),
+    #so every block is replicated on its own
+    rows_per_block = nrow(X_bg_samples) / n_perms_in_samples
+    n_reps = nrow(x_mask_full) / rows_per_block / n_perms_in_samples
+    masked_X_bg = u_repeatMatrixBlocks(X_bg_samples, rows_per_block, n_reps)
+  } else {
+    #one shared sample set: x_mask_full is already replicated row-wise by nrow(X_bg_samples),
+    #so the whole set is repeated once per original mask row
+    masked_X_bg = u_repeatMatrix(X_bg_samples, nrow(x_mask_full) / nrow(X_bg_samples))
+  }
+
+  #keep background values only where the mask does not take the instance's value
+  masked_X_bg = masked_X_bg * !x_mask_full
+}
+
+# Applies the masked background and the boolean mask to an individual instance of interest.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# x_row        Instance of interest as row-vector.
+# x_mask_full  Boolean mask replicated row-wise.
+# masked_X_bg  Prepared background samples.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# X_masked     Set of synthesized instances for x_row.
+# -----------------------------------------------------------------------------
+apply_full_mask = function(Matrix[Double] x_row, Matrix[Double] x_mask_full, Matrix[Double] masked_X_bg)
+return (Matrix[Double] X_masked){
+  #instance values where the mask is 1 (x_row is broadcast over all mask rows) ...
+  instance_part = x_mask_full * x_row
+  #... plus the prepared background values everywhere else
+  X_masked = instance_part + masked_X_bg
+}
+
+# Removes all rows from the prepared masks and background data whenever their feature is marked as
+# non-varying, so that for each permutation the remaining rows have the layout of a permutation over
+# the varying features only (2 rows per remaining feature).
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# masks              Prepared and replicated mask for a single instance.
+# masked_X_bg        Prepared and replicated background data.
+# full_permutations  The permutations from which the masks and background data were created.
+# non_var_inds       A row-vector containing the indices that were found to be not varying for
+#                    this instance (assumed to hold positive indices only).
+# n_samples          The number of samples over which each row is integrated.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# sub_mask           A subset of masks where for each permutation the rows that correspond to
+#                    non-varying features are removed.
+# sub_masked_X_bg    A subset of the background data where for each permutation the rows that
+#                    correspond to non-varying features are removed.
+# -----------------------------------------------------------------------------
+remove_inds = function(Matrix[Double] masks, Matrix[Double] masked_X_bg, Matrix[Double] full_permutations,
+  Matrix[Double] non_var_inds, Integer n_samples)
+return(Matrix[Double] sub_mask, Matrix[Double] sub_masked_X_bg){
+  #offset of each permutation (row) when all permutations are flattened into one vector
+  offsets = seq(0,length(full_permutations)-ncol(full_permutations), ncol(full_permutations))
+
+  ###
+  # get row indices from permutations
+  total_row_index = full_permutations + offsets
+  total_row_index = matrix(total_row_index, rows=length(total_row_index), cols=1)
+
+  row_index = toOneHot(total_row_index, nrow(total_row_index))
+  ####
+  # get indices for all permutations as boolean mask
+  # repeat inds for every permutation
+  non_var_inds = matrix(1, rows=nrow(full_permutations), cols=ncol(non_var_inds)) * non_var_inds
+  #add offset
+  non_var_total = non_var_inds + offsets
+  #reshape into col-vec
+  non_var_total = matrix(non_var_total,rows=length(non_var_total), cols=1, byrow=FALSE)
+  non_var_mask = toOneHot(non_var_total, nrow(total_row_index))
+
+  non_var_mask = colSums(non_var_mask)
+
+  ###
+  # multiply to get, for every position of every permutation, whether its feature is non-varying
+  non_var_rows = row_index %*% t(non_var_mask)
+
+  ####
+  # unfold to full mask length
+  # reshape to one column per permutation (row k = k-th feature of that permutation)
+  reshaped_rows = matrix(non_var_rows, rows=ncol(full_permutations), cols=nrow(full_permutations), byrow=FALSE)
+
+  reshaped_rows_full = matrix(0,rows=1,cols=ncol(reshaped_rows))
+
+  #the masks of each permutation are [inverted full mask; forward masks 1..p; inverted forward
+  #masks p-1..1], i.e. the backward rows come in REVERSE feature order, so their selectors are
+  #reversed as well; rbind to manipulate all perms at once
+  if( sum(reshaped_rows[nrow(reshaped_rows)]) > 0 ){
+    #fix last row issue: if the last feature is non-varying, the full forward row is removed, so the
+    #backward row of the last varying feature (only non-varying features set) is removed as well
+    row_indicator = (!reshaped_rows) * seq(1, nrow(reshaped_rows), 1)
+    row_indicator = colMaxs(row_indicator)
+    row_indicator = t(toOneHot(t(row_indicator), nrow(reshaped_rows)))
+    reshaped_rows_2 = reshaped_rows[1:nrow(reshaped_rows)-1] + row_indicator[1:nrow(reshaped_rows)-1]
+    reshaped_rows_full = rbind(reshaped_rows_full, reshaped_rows, rev(reshaped_rows_2))
+  }else{
+    reshaped_rows_full = rbind(reshaped_rows_full, reshaped_rows, rev(reshaped_rows[1:nrow(reshaped_rows)-1]))
+  }
+  #reshape into col-vec
+  non_var_total = matrix(reshaped_rows_full, rows=length(reshaped_rows_full), cols=1, byrow=FALSE)
+
+  #replicate, if masks already replicated
+  if (n_samples > 1){
+    non_var_total = matrix(1, rows=nrow(non_var_total), cols=n_samples) * non_var_total
+    non_var_total = matrix(non_var_total, rows=length(non_var_total), cols=1)
+  }
+
+  #remove from mask and background according to this vector
+  sub_mask = removeEmpty(target=masks, select=!non_var_total, margin="rows")
+  sub_masked_X_bg = removeEmpty(target=masked_X_bg, select=!non_var_total, margin="rows")
+}
+
+# Performs the integration/marginalization by computing means over groups of n_samples
+# consecutive prediction rows.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# P          Predictions from model (one row per synthesized instance, one or more columns).
+# n_samples  Number of samples over which to take the mean.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# P_means    The means of the sample groups. Each row is one group with means in cols.
+# -----------------------------------------------------------------------------
+compute_means_from_predictions = function(Matrix[Double] P, Integer n_samples)
+  return (Matrix[Double] P_means){
+  n_features = nrow(P)/n_samples
+
+  #transpose and flatten, so the values of each prediction column are contiguous
+  # TODO: unnecessary for vectors, only t() would be needed
+  P = matrix(t(P), cols=1, rows=length(P))
+
+  #reshape, so all predictions from one batch are in one row
+  P = matrix(P, cols=n_samples, rows=length(P)/n_samples)
+
+  #compute row means
+  P_means = rowMeans(P)
+
+  #the means are ordered prediction column by prediction column, so fold them back column-wise
+  #to get one row per group and one column per prediction column (identical for vectors)
+  P_means = matrix(P_means, rows=n_features, cols=length(P_means)/n_features, byrow=FALSE)
+}
+
+# Computes phis from the mean predictions of the forward and backward passes of all permutations.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# P                     Mean predictions for multiple permutations (2 * #features rows per permutation).
+# permutations          Permutations to get the feature indices from.
+# non_var_inds          Matrix holding the indices of non-varying features in the permutation that
+#                       were ignored during prediction. These will be removed from the <permutations>
+#                       during computation of the phis.
+# n_partition_features  Number of features that are in partitions - number of partitions:
+#                       There is still one feature per partition kept in the perms!
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# phis                  Phis or Shapley values averaged over all permutations.
+#                       Every row holds the phi for the corresponding feature.
+# expected              Mean prediction of the first row of each permutation (no feature taken
+#                       from the instance).
+# -----------------------------------------------------------------------------
+compute_phis_from_prediction_means = function(Matrix[Double] P, Matrix[Double] permutations,
+  Matrix[Double] non_var_inds=as.matrix(-1), Integer n_partition_features = 0)
+return(Matrix[Double] phis, Double expected){
+  perm_len=ncol(permutations)
+  n_non_var_inds = 0
+  partial_permutations = permutations
+
+  if(sum(non_var_inds)>0){
+    n_non_var_inds = ncol(non_var_inds)
+    #flatten perms to remove from all perms at once
+    perms_flattened = matrix(permutations, rows=length(permutations), cols=1)
+    rem_selector = outer(perms_flattened, non_var_inds, "==")
+    rem_selector = rowSums(rem_selector)
+    partial_permutations = removeEmpty(target=perms_flattened, select=!rem_selector, margin="rows")
+    #reshape (only the row-major element order matters, it is flattened again below)
+    partial_permutations = matrix(partial_permutations, rows=perm_len-n_non_var_inds, cols=nrow(permutations))
+    perm_len = perm_len-n_non_var_inds
+  }
+
+  #reshape P to get one col per permutation
+  P_perm = matrix(P, rows=2*perm_len, cols=nrow(permutations), byrow=FALSE)
+
+  #forwards phis
+  forward_phis = P_perm[2:perm_len+1] - P_perm[1:perm_len]
+
+  #backward phis: row perm_len+2 adds the last feature to row 1, and the full row perm_len+1
+  #adds the first feature to the last backward row
+  backward_first = P_perm[perm_len+2] - P_perm[1]
+  backward_last = P_perm[perm_len+1] - P_perm[2*perm_len]
+  if(perm_len > 2){
+    backward_inner = P_perm[perm_len+3:2*perm_len] - P_perm[perm_len+2:2*perm_len-1]
+    backward_phis = rbind(backward_first, backward_inner, backward_last)
+  }else{
+    #with two features there are no inner backward steps (that index range would be empty)
+    backward_phis = rbind(backward_first, backward_last)
+  }
+  #reverse to match order of features in permutation
+  backward_phis = rev(backward_phis)
+  #avg forward and backward
+  forward_phis = matrix(forward_phis, rows=length(forward_phis), cols=1, byrow=FALSE)
+  backward_phis = matrix(backward_phis, rows=length(backward_phis), cols=1, byrow=FALSE)
+  avg_phis = (forward_phis + backward_phis) / 2
+
+  #aggregate to get only one phi per feature (and implicitly add zeros for non var inds)
+  perms_flattened = matrix(partial_permutations, rows=length(partial_permutations), cols=1)
+  phis = aggregate(target=avg_phis, groups=perms_flattened, fn="mean", ngroups=ncol(permutations)+n_partition_features)
+
+  #get expected from first row
+  expected=mean(P_perm[1])
+}
+
+# Removes features that are part of a partition.
+# Keeps the first feature of each partition as proxy for the partition.
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# features        Matrix holding features in its cols.
+# partitions      Matrix holding start and end of partitions in the cols of the first and
+#                 second row respectively.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# short_features  Matrix like features, but with the ones from partitions removed.
+# -----------------------------------------------------------------------------
+removePartitionsFromFeatures = function(Matrix[Double] features, Matrix[Double] partitions)
+return (Matrix[Double] short_features){
+  #count, per feature, the partitions it lies in behind the partition start (the start stays as proxy)
+  in_partition = matrix(0, rows=1, cols=ncol(features))
+  for (j in 1:ncol(partitions)){
+    p_first = as.scalar(partitions[1,j])
+    p_last = as.scalar(partitions[2,j])
+    in_partition = in_partition + (features > p_first) * (features <= p_last)
+  }
+  #keep only the features outside of every partition body
+  short_features = removeEmpty(target=features, margin="cols", select=(in_partition == 0))
+}
+
+########################
+# Utility functions that might be worth refactoring into their own file.
+# They could be used in other scenarios as well.
+########################
+
+
+# Samples from the background data X_bg.
+# The function first uses all background samples without replacement, but if more samples are
+# requested than available in X_bg, it shuffles X_bg and pulls more samples from it, making it
+# sampling with replacement.
+# TODO: Might be replaceable by other builtin for sampling in the future
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# X_bg            Matrix of background data
+# samples         Number of total samples
+# always_shuffle  Boolean to enable reshuffling of X_bg, defaults to false.
+# seed            A seed for the shuffling and sampling (-1 for random).
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# X_sample        New Matrix containing #samples rows from X_bg, potentially with replacement.
+# -----------------------------------------------------------------------------
+u_sample_with_potential_replace = function(Matrix[Double] X_bg, Integer samples, Boolean always_shuffle = 0, Integer seed)
+return (Matrix[Double] X_sample){
+  number_of_bg_samples = nrow(X_bg)
+
+  # expect to not use all from background and subsample from it
+  num_of_full_X_bg = 0
+  num_of_remainder_samples = samples
+
+  # shuffle background if desired
+  if(always_shuffle) {
+    X_bg = u_shuffle(X_bg)
+  }
+
+  # list to store references to generated matrices so we can rbind them in one call
+  samples_list = list()
+
+  # in case we need at least all of the background data, use it (multiple times, with replacement)
+  if(samples >= number_of_bg_samples) {
+    # only warn when rows actually have to be reused
+    if(samples > number_of_bg_samples) {
+      u_printShapMessage("WARN: More samples ("+toString(samples)+") are requested than available in the background dataset ("+toString(number_of_bg_samples)+"). Using replacement", 1)
+    }
+
+    # get number of full sets of background by integer division
+    num_of_full_X_bg = samples %/% number_of_bg_samples
+    # get remaining samples using modulo
+    num_of_remainder_samples = samples %% number_of_bg_samples
+
+    #use background data once
+    samples_list = append(samples_list, X_bg)
+
+    if(num_of_full_X_bg > 1){
+      # add shuffled versions of background data
+      for (i in 1:(num_of_full_X_bg-1)){
+        # derive one seed per copy from the given seed, so a fixed seed gives reproducible copies
+        copy_seed = -1
+        if(seed != -1) {
+          copy_seed = seed + i
+        }
+        sort_keys = rand(rows=number_of_bg_samples, cols=1, min=0, max=1, pdf="uniform", seed=copy_seed)
+        shuffled = order(target=cbind(X_bg, sort_keys), by=ncol(X_bg)+1)
+        samples_list = append(samples_list, shuffled[,1:ncol(X_bg)])
+      }
+    }
+  }
+
+  # sample from background dataset for remaining samples
+  if (num_of_remainder_samples > 0){
+    # pick remaining samples
+    random_samples_indices = sample(number_of_bg_samples, num_of_remainder_samples, seed)
+
+    # indicator column marking the picked rows; selecting rows via this vector (instead of
+    # dropping empty rows after a multiplication) also keeps picked rows that are entirely zero
+    picked = table(random_samples_indices, matrix(1, rows=num_of_remainder_samples, cols=1), number_of_bg_samples, 1)
+    samples_list = append(samples_list, removeEmpty(target=X_bg, margin="rows", select=picked))
+  }
+
+  if ( length(samples_list) == 1){
+    #dont copy if only one matrix is in list, since this is a heavy hitter
+    X_sample = as.matrix(samples_list[1])
+  } else {
+    #single call to bind all generated samples into one large matrix
+    X_sample = rbind(samples_list)
+  }
+}
+
+# Simple utility function to shuffle the rows of a matrix (from shuffle.dml, but without
+# storing to file).
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# X           Matrix to be shuffled
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# X_shuffled  Matrix like X, with its rows in random order.
+# -----------------------------------------------------------------------------
+u_shuffle = function(Matrix[Double] X)
+return (Matrix[Double] X_shuffled){
+  n_cols = ncol(X)
+  # attach one uniform random sort key per row and order the rows by it
+  sort_keys = rand(rows=nrow(X), cols=1, min=0, max=1, pdf="uniform")
+  sorted = order(target = cbind(X, sort_keys), by = n_cols + 1)
+  # drop the key column again
+  X_shuffled = sorted[, 1:n_cols]
+}
+
+# Simple utility function to create a range of integers from start to end (both inclusive).
+#
+# INPUT:
+# ---------------------------------------------------------------------------------------
+# start  First integer of range.
+# end    Last integer of range.
+# ---------------------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# range  Row-vector with the range from start to end in its cols.
+# -----------------------------------------------------------------------------
+u_range = function(Integer start, Integer end)
+return (Matrix[Double] range){
+  n = end - start + 1
+  # running sum of ones yields 1..n as a column; shift it to start..end and turn it into a row
+  ones_col = matrix(1, rows=n, cols=1)
+  range = t(cumsum(ones_col) + (start - 1))
+}
+
+# Replicates each row of the input matrix n_times, keeping the copies next to each other.
+#
+# Example:
+# [1,2]
+# [3,4]
+# becomes
+# [1,2]
+# [1,2]
+# [3,4]
+# [3,4]
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# M        Matrix where rows will be replicated.
+# n_times  Number of replications.
+# -----------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# M        Matrix of replicated rows.
+# -----------------------------------------------------------------------------
+u_repeatRows = function(Matrix[Double] M, Integer n_times)
+return(Matrix[Double] M){
+  n_out = nrow(M) * n_times
+  # source row of every output row, e.g. 1,1,1,2,2,2 for 2 rows replicated 3 times
+  source_rows = ceil(seq(1, n_out, 1) / n_times)
+  # one-hot selection matrix: output row r picks source row source_rows[r]
+  M = toOneHot(source_rows, nrow(M)) %*% M
+}
+
+# Replicates the whole matrix n_times, stacking the copies below each other.
+#
+# Example:
+# [1,2]
+# [3,4]
+# becomes
+# [1,2]
+# [3,4]
+# [1,2]
+# [3,4]
+#
+# INPUT:
+# -----------------------------------------------------------------------------
+# M        Matrix to be replicated.
+# n_times  Number of replications.
+# -----------------------------------------------------------------------------
+#
+# OUTPUT:
+# -----------------------------------------------------------------------------
+# M        Matrix of stacked copies.
+# -----------------------------------------------------------------------------
+u_repeatMatrix = function(Matrix[Double] M, Integer n_times)
+return(Matrix[Double] M){
+  orig_rows = nrow(M)
+  orig_cols = ncol(M)
+  # flatten row-wise into a single row, broadcast it to n_times rows and fold back
+  flat = matrix(M, rows=1, cols=length(M))
+  stacked = matrix(1, rows=n_times, cols=1) * flat
+  M = matrix(stacked, rows=orig_rows*n_times, cols=orig_cols)
+}
+
+# Like u_repeatMatrix(), but treats every rows_per_block consecutive rows as one block and repeats
+# each block n_times before continuing with the next block
+# (e.g. blocks B1, B2 with n_times=2 become B1, B1, B2, B2).
+u_repeatMatrixBlocks = function(Matrix[Double] M, Integer rows_per_block, Integer n_times)
+return(Matrix[Double] M){
+  in_rows = nrow(M)
+  in_cols = ncol(M)
+  # fold every block into a single row, so whole blocks can be replicated like rows
+  block_rows = matrix(M, rows=in_rows/rows_per_block, cols=in_cols*rows_per_block)
+  block_rows = u_repeatRows(block_rows, n_times)
+  # unfold the replicated blocks back to the original column count
+  M = matrix(block_rows, rows=in_rows*n_times, cols=in_cols)
+}
+
+#utility function to print a message with the shap-explainer tag, only if verbose is set
+u_printShapMessage = function(String message, Boolean verbose){
+  tagged_message = "shap-explainer::" + message
+  if(verbose){
+    print(tagged_message)
+  }
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 98f92ae55e..ca31cd331a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -300,6 +300,7 @@ public enum Builtins {
SELVARTHRESH("selectByVarThresh", true),
SEQ("seq", false),
SYMMETRICDIFFERENCE("symmetricDifference", true),
+ SHAPEXPLAINER("shapExplainer", true),
SHERLOCK("sherlock", true),
SHERLOCKPREDICT("sherlockPredict", true),
SHORTESTPATH("shortestPath", true),
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java
new file mode 100644
index 0000000000..0c207d66f3
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinShapExplainerTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+/**
+ * Tests for the {@code shapExplainer} builtin.
+ * <p>
+ * The unit tests run {@code shapExplainerUnit.dml}, which evaluates one internal
+ * helper of {@code scripts/builtin/shapExplainer.dml} per test type and writes the
+ * computed and the expected matrix. The component test runs
+ * {@code shapExplainerComponent.dml}, which calls the builtin end-to-end on dummy
+ * data and writes the computed and expected phis and model expectation.
+ */
+public class BuiltinShapExplainerTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME = "shapExplainer";
+	private static final String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinShapExplainerTest.class.getSimpleName() + "/";
+
+	/** Absolute tolerance used for all matrix comparisons. */
+	private static final double EPS = 1e-3;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"R"}));
+	}
+
+	@Test
+	public void testPrepareMaskForPermutation() {
+		runShapExplainerUnitTest("prepare_mask_for_permutation");
+	}
+
+	@Test
+	public void testPrepareMaskForPartialPermutation() {
+		runShapExplainerUnitTest("prepare_mask_for_partial_permutation");
+	}
+
+	@Test
+	public void testPrepareMaskForPartitionedPermutation() {
+		runShapExplainerUnitTest("prepare_mask_for_partitioned_permutation");
+	}
+
+	@Test
+	public void testComputeMeansFromPredictions() {
+		runShapExplainerUnitTest("compute_means_from_predictions");
+	}
+
+	@Test
+	public void testComputePhisFromPredictionMeans() {
+		runShapExplainerUnitTest("compute_phis_from_prediction_means");
+	}
+
+	@Test
+	public void testComputePhisFromPredictionMeansNonVars() {
+		runShapExplainerUnitTest("compute_phis_from_prediction_means_non_vars");
+	}
+
+	@Test
+	public void testPrepareFullMask() {
+		runShapExplainerUnitTest("prepare_full_mask");
+	}
+
+	@Test
+	public void testPrepareMaskedXBg() {
+		runShapExplainerUnitTest("prepare_masked_X_bg");
+	}
+
+	@Test
+	public void testPrepareMaskedXBgIndependentPerms() {
+		runShapExplainerUnitTest("prepare_masked_X_bg_independent_perms");
+	}
+
+	@Test
+	public void testApplyFullMask() {
+		runShapExplainerUnitTest("apply_full_mask");
+	}
+
+	@Test
+	public void testShapExplainerDummyData() {
+		runShapExplainerComponentTest();
+	}
+
+	//TODO add test with real data
+
+	/**
+	 * Runs one helper-function unit test of shapExplainerUnit.dml and compares
+	 * the computed matrix against the expected one written by the script.
+	 *
+	 * @param testType name of the helper test case, passed to the script as $1
+	 */
+	private void runShapExplainerUnitTest(String testType) {
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+		try {
+			//load inside try so the exec mode is always restored
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			//execute given unit test
+			fullDMLScriptName = HOME + TEST_NAME + "Unit.dml";
+			programArgs = new String[]{"-args", testType, output("R"), output("R_expected")};
+			runTest(true, false, null, -1);
+
+			//compare to expected result
+			compareOutputMatrices("R", "R_expected", testType + "_result", testType + "_expected");
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+
+	/**
+	 * Runs the explainer end-to-end on the dummy data of shapExplainerComponent.dml
+	 * and compares the computed phis and model expectation against the expected
+	 * values written by the script.
+	 */
+	private void runShapExplainerComponentTest() {
+		ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+		try {
+			//load inside try so the exec mode is always restored
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+			String HOME = SCRIPT_DIR + TEST_DIR;
+
+			//execute component test; the script appends "_phis"/"_e" to both output paths
+			fullDMLScriptName = HOME + TEST_NAME + "Component.dml";
+			programArgs = new String[]{"-args", output("R"), output("R_expected")};
+			runTest(true, false, null, -1);
+
+			//compare to expected phis
+			compareOutputMatrices("R_phis", "R_expected_phis", "explainer_result_phis", "explainer_expected_phis");
+
+			//compare to expected value of model; the script pads the scalar with a
+			//zero into a 1x2 matrix (FIXME: writing the scalar directly failed)
+			compareOutputMatrices("R_e", "R_expected_e", "explainer_result_e", "explainer_expected_e");
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+
+	/**
+	 * Reads two matrices from the test output directory and compares them
+	 * cell-wise within {@link #EPS}.
+	 */
+	private void compareOutputMatrices(String resultName, String expectedName,
+		String resultLabel, String expectedLabel)
+	{
+		HashMap<CellIndex, Double> result = readDMLMatrixFromOutputDir(resultName);
+		HashMap<CellIndex, Double> expected = readDMLMatrixFromOutputDir(expectedName);
+		TestUtils.compareMatrices(result, expected, EPS, resultLabel, expectedLabel);
+	}
+}
diff --git a/src/test/scripts/functions/builtin/shapExplainerComponent.dml
b/src/test/scripts/functions/builtin/shapExplainerComponent.dml
new file mode 100644
index 0000000000..8bf444b665
--- /dev/null
+++ b/src/test/scripts/functions/builtin/shapExplainerComponent.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+########################################################################################################
+# THIS TEST IS HIGHLY DEPENDENT ON THE SAMPLING!
+# Changes to the dataset, the number of samples, etc. might already be enough to change the expected result.
+########################################################################################################
+
+# sampling configuration of the explainer
+seed = 42
+n_samples = 3
+n_permutations = 2
+
+# named arguments for dummyModel (handed to shapExplainer below)
+model_args = list(mult=1)
+
+# 3 instances to explain and 4 background rows, each with 3 features
+X_bg = matrix("11 12 13 21 22 23 31 32 33 41 42 43", cols=3, rows=4)
+x_instances = matrix("100 200 300 100 300 400 100 100 500", cols=3, rows=3)
+
+# Deterministic linear toy model for the explainer test: the prediction for
+# each row of X is its row sum scaled by mult.
+dummyModel = function(Matrix[Double] X, Double mult)
+  return(Matrix[Double] P){
+  row_totals = rowSums(X)
+  P = mult * row_totals
+}
+
+[result_phis, result_e] = shapExplainer("dummyModel", model_args, x_instances, X_bg, n_permutations, n_samples, 0, as.matrix(-1), seed, 1)
+
+# TODO: writing the scalar result directly caused errors (possibly due to
+# locale-dependent decimal separators, e.g. a German locale), so it is padded
+# with a zero into a 1x2 matrix before writing.
+result_e = cbind(as.matrix(result_e), as.matrix(0))
+
+# reference values for exactly this data, sample count and seed
+expected_result_phis = matrix("69 168 267 69 268 367 69 68 467", rows=3, cols=3)
+expected_result_e = matrix("96 0", rows=1, cols=2)
+
+# $1 is the prefix for computed outputs, $2 the prefix for expected outputs
+out_phis = $1 + "_phis"
+out_e = $1 + "_e"
+ref_phis = $2 + "_phis"
+ref_e = $2 + "_e"
+
+write(result_phis, out_phis)
+write(result_e, out_e)
+write(expected_result_phis, ref_phis)
+write(expected_result_e, ref_e)
+
diff --git a/src/test/scripts/functions/builtin/shapExplainerUnit.dml
b/src/test/scripts/functions/builtin/shapExplainerUnit.dml
new file mode 100644
index 0000000000..c5f227a67e
--- /dev/null
+++ b/src/test/scripts/functions/builtin/shapExplainerUnit.dml
@@ -0,0 +1,104 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("scripts/builtin/shapExplainer.dml") as shap;
+
+# Unit tests for the internal helpers of scripts/builtin/shapExplainer.dml.
+# $1 selects the test case; the computed matrix is written to $2 and the
+# expected matrix to $3, which BuiltinShapExplainerTest compares.
+if ($1 == 'prepare_mask_for_permutation') {
+  #prepare_mask_for_permutation
+  perm = matrix("3 1 2", cols=3, rows=1)
+  result = shap::prepare_mask_for_permutation(permutation=perm)
+  expected_result = matrix("0 0 0 0 0 1 1 0 1 1 1 1 0 1 0 1 1 0", rows=6, cols=3)
+
+} else if ($1 == 'prepare_mask_for_partial_permutation') {
+  #prepare_mask_for_partial_permutation
+  perm = matrix("4 1 2", cols=3, rows=1)
+  result = shap::prepare_mask_for_permutation(permutation=perm, n_non_varying_inds=2)
+  expected_result = matrix("0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 0 1 1 0 1 1 1 1 0 1", rows=6, cols=5)
+
+} else if ($1 == 'prepare_mask_for_partitioned_permutation') {
+  #prepare_mask_for_partitioned_permutation
+  perm = matrix("4 1 2", cols=3, rows=1)
+  partitions = matrix("2 4 3 5", cols=2, rows=2)
+  result = shap::prepare_mask_for_permutation(permutation=perm, partitions=partitions)
+  expected_result = matrix("0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 1 1 0 0 1 1 1 0 0", rows=6, cols=5)
+
+} else if ($1 == 'compute_means_from_predictions') {
+  #compute_means_from_predictions
+  p = matrix("2 3 3 4 4 5", rows=6, cols=1)
+  result = shap::compute_means_from_predictions(p, 2)
+  expected_result = matrix("2.5 3.5 4.5", rows=3, cols=1)
+
+} else if ($1 == 'compute_phis_from_prediction_means') {
+  #compute_phis_from_prediction_means
+  permutation = matrix("2 3 4 1 5", cols=5, rows=1)
+  P_perm = matrix("10 21 22 23 24 100 31 32 33 34", rows=10, cols=1)
+  result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation)
+  expected_result = matrix("1 38.5 1 1 48.5", rows=5, cols=1)
+
+} else if ($1 == 'compute_phis_from_prediction_means_non_vars') {
+  #compute_phis_from_prediction_means with non varying inds
+  permutation = matrix("3 4 2 1 5", cols=5, rows=1)
+  non_varying_inds= matrix("2", rows=1, cols=1)
+  P_perm = matrix("10 22 23 24 100 31 32 33", rows=8, cols=1)
+  result = shap::compute_phis_from_prediction_means(P=P_perm, permutations=permutation, non_var_inds=non_varying_inds)
+  expected_result = matrix("1 0 39.5 1 48.5", rows=5, cols=1)
+
+} else if ($1 == 'prepare_full_mask') {
+  #prepare_full_mask: each mask row is expected to be repeated 3 times in place;
+  #the result must come from prepare_full_mask itself, not from a helper
+  mask = matrix("1 0 0 1", rows=2, cols=2)
+  result = shap::prepare_full_mask(mask, 3)
+  expected_result = matrix("1 0 1 0 1 0 0 1 0 1 0 1", rows=6, cols=2)
+
+} else if ($1 == 'prepare_masked_X_bg') {
+  #prepare_masked_X_bg
+  mask = matrix("1 0 0 1", rows=2, cols=2)
+  full_mask = shap::prepare_full_mask(mask, 3)
+  X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2)
+  result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0)
+  expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0", rows=6, cols=2)
+
+} else if ($1 == 'prepare_masked_X_bg_independent_perms') {
+  #prepare_masked_X_bg for independent perms
+  mask = matrix("1 0 0 1 1 0 0 1", rows=4, cols=2)
+  full_mask = shap::prepare_full_mask(mask, 3)
+  X_bg_samples = matrix("11 12 21 22 31 32 41 42 51 52 61 62", rows=6, cols=2)
+  result = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 2)
+  expected_result = matrix("0 12 0 22 0 32 11 0 21 0 31 0 0 42 0 52 0 62 41 0 51 0 61 0", rows=12, cols=2)
+
+} else if ($1 == 'apply_full_mask') {
+  #apply_full_mask
+  x_row = matrix("100 200", rows=1, cols=2)
+  mask = matrix("1 0 0 1", rows=2, cols=2)
+  full_mask = shap::prepare_full_mask(mask, 3)
+  X_bg_samples = matrix("11 12 21 22 31 32", rows=3, cols=2)
+  masked_X_bg = shap::prepare_masked_X_bg(full_mask, X_bg_samples, 0)
+  result = shap::apply_full_mask(x_row, full_mask, masked_X_bg)
+  expected_result = matrix("100 12 100 22 100 32 11 200 21 200 31 200", rows=6, cols=2)
+
+} else {
+  #unknown test type: write deliberately mismatching matrices so the comparison fails
+  print("Test type "+$1+" unknown.")
+  result = matrix("100 100", rows=1, cols=2)
+  expected_result = matrix("0 0", rows=1, cols=2)
+}
+
+write(result, $2)
+write(expected_result, $3)
+