andishgar commented on code in PR #47586:
URL: https://github.com/apache/arrow/pull/47586#discussion_r2426542780


##########
cpp/src/arrow/sparse_tensor.cc:
##########
@@ -475,4 +523,292 @@ Result<std::shared_ptr<Tensor>> 
SparseTensor::ToTensor(MemoryPool* pool) const {
   }
 }
 
+namespace {
+
+struct SparseTensorValidatorBase {
+  SparseTensorValidatorBase(const Tensor& tensor, const SparseTensor& 
sparse_tensor)
+      : tensor(tensor), sparse_tensor(sparse_tensor) {}
+
+  template <typename ValueType>
+  Status ValidateValue(typename ValueType::c_type sparse_tensor_value,
+                       typename ValueType::c_type tensor_value) {
+    if (!internal::is_not_zero<ValueType>(sparse_tensor_value)) {
+      return Status::Invalid("Sparse tensor values must be non-zero");
+    } else if (sparse_tensor_value != tensor_value) {
+      if constexpr (is_floating_type<ValueType>::value) {
+        if (!std::isnan(tensor_value) || !std::isnan(sparse_tensor_value)) {
+          return Status::Invalid(
+              "Inconsistent values between sparse tensor and dense tensor");
+        }
+      } else {
+        return Status::Invalid(
+            "Inconsistent values between sparse tensor and dense tensor");
+      }
+    }
+    return Status::OK();
+  }
+
+  const Tensor& tensor;
+  const SparseTensor& sparse_tensor;
+};
+
+struct SparseCOOValidator : public SparseTensorValidatorBase {
+  using SparseTensorValidatorBase::SparseTensorValidatorBase;
+
+  Status Validate() {
+    auto sparse_coo_index =
+        
internal::checked_pointer_cast<SparseCOOIndex>(sparse_tensor.sparse_index());
+    auto indices = sparse_coo_index->indices();
+    RETURN_NOT_OK(CheckSparseCOOIndexValidity(indices->type(), 
indices->shape(),
+                                              indices->strides()));
+    // Validate Values
+    return util::VisitCOOTensorType(*sparse_tensor.type(), *indices->type(), 
*this);
+  }
+
+  template <typename ValueType, typename IndexType>
+  Status operator()(const ValueType& value_type, const IndexType& index_type) {
+    return ValidateSparseCooTensorValues(value_type, index_type);
+  }
+
+  template <typename ValueType, typename IndexType>
+  Status ValidateSparseCooTensorValues(const ValueType&, const IndexType&) {
+    using IndexCType = typename IndexType::c_type;
+    using ValueCType = typename ValueType::c_type;
+
+    auto sparse_coo_index =
+        
internal::checked_pointer_cast<SparseCOOIndex>(sparse_tensor.sparse_index());
+    auto sparse_coo_values_buffer = sparse_tensor.data();
+
+    const auto& indices = sparse_coo_index->indices();
+    const auto* indices_data = 
sparse_coo_index->indices()->data()->data_as<IndexCType>();
+    const auto* sparse_coo_values = 
sparse_coo_values_buffer->data_as<ValueCType>();
+
+    ARROW_ASSIGN_OR_RAISE(auto non_zero_count, tensor.CountNonZero());
+
+    if (indices->shape()[0] != non_zero_count) {
+      return Status::Invalid("Mismatch between non-zero count in sparse tensor 
(",
+                             indices->shape()[0], ") and dense tensor (", 
non_zero_count,
+                             ")");
+    } else if (indices->shape()[1] != 
static_cast<int64_t>(tensor.shape().size())) {
+      return Status::Invalid("Mismatch between coordinate dimension in sparse 
tensor (",
+                             indices->shape()[1], ") and tensor shape (",
+                             tensor.shape().size(), ")");
+    }
+
+    auto coord_size = indices->shape()[1];
+    std::vector<int64_t> coord(coord_size);
+    for (int64_t i = 0; i < indices->shape()[0]; i++) {
+      for (int64_t j = 0; j < coord_size; j++) {
+        coord[j] = static_cast<int64_t>(indices_data[i * coord_size + j]);
+      }
+      ARROW_RETURN_NOT_OK(
+          ValidateValue<ValueType>(sparse_coo_values[i], 
tensor.Value<ValueType>(coord)));
+    }
+
+    return Status::OK();
+  }
+};
+
+template <typename SparseCSXIndex>
+struct SparseCSXValidator : public SparseTensorValidatorBase {
+  SparseCSXValidator(const Tensor& tensor, const SparseTensor& sparse_tensor)
+      : SparseTensorValidatorBase(tensor, sparse_tensor) {
+    sparse_csx_index =
+        
internal::checked_pointer_cast<SparseCSXIndex>(sparse_tensor.sparse_index());
+  }
+
+  Status Validate() {
+    auto indptr = sparse_csx_index->indptr();
+    auto indices = sparse_csx_index->indices();
+    ARROW_RETURN_NOT_OK(
+        internal::ValidateSparseCSXIndex(indptr->type(), indices->type(), 
indptr->shape(),
+                                         indices->shape(), 
sparse_csx_index->kTypeName));
+    return util::VisitCSXType(*sparse_tensor.type(), *indices->type(), 
*indptr->type(),
+                              *this);
+  }
+
+  template <typename ValueType, typename IndexType, typename IndexPointerType>
+  Status operator()(const ValueType& value_type, const IndexType& index_type,
+                    const IndexPointerType& index_pointer_type) {

Review Comment:
   That isn't possible for this operator, because `IndexPointerType` and 
`IndexType` can be different types. For reference, see the [IPC format 
definition](https://github.com/apache/arrow/blob/bfce5f208e2470648fb9b1a47d0e6521a278efaf/format/SparseTensor.fbs#L85-L124)



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to