corepointer commented on a change in pull request #1101:
URL: https://github.com/apache/systemds/pull/1101#discussion_r523475631
##########
File path: src/main/cpp/systemds.cpp
##########
@@ -68,57 +51,62 @@
// -------------------------------------------------------------------
-int maxThreads = -1;
+//int maxThreads = -1;
JNIEXPORT void JNICALL
Java_org_apache_sysds_utils_NativeHelper_setMaxNumThreads
(JNIEnv *, jclass, jint jmaxThreads) {
- maxThreads = (int) jmaxThreads;
+// maxThreads = (int) jmaxThreads;
+ setNumThreadsForBLAS(jmaxThreads);
}
-JNIEXPORT jboolean JNICALL Java_org_apache_sysds_utils_NativeHelper_dmmdd(
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_dmmdd(
JNIEnv* env, jclass cls, jdoubleArray m1, jdoubleArray m2, jdoubleArray
ret,
jint m1rlen, jint m1clen, jint m2clen, jint numThreads)
{
double* m1Ptr = GET_DOUBLE_ARRAY(env, m1, numThreads);
double* m2Ptr = GET_DOUBLE_ARRAY(env, m2, numThreads);
double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
if(m1Ptr == NULL || m2Ptr == NULL || retPtr == NULL)
- return (jboolean) false;
+ return -1;
dmatmult(m1Ptr, m2Ptr, retPtr, (int)m1rlen, (int)m1clen, (int)m2clen,
(int)numThreads);
+ size_t nnz = computeNNZ<double>(retPtr, m1rlen * m2clen);
RELEASE_INPUT_ARRAY(env, m1, m1Ptr, numThreads);
RELEASE_INPUT_ARRAY(env, m2, m2Ptr, numThreads);
RELEASE_ARRAY(env, ret, retPtr, numThreads);
- return (jboolean) true;
+
+ return static_cast<jlong>(nnz);
}
-JNIEXPORT jboolean JNICALL Java_org_apache_sysds_utils_NativeHelper_smmdd(
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_smmdd(
JNIEnv* env, jclass cls, jobject m1, jobject m2, jobject ret,
jint m1rlen, jint m1clen, jint m2clen, jint numThreads)
{
float* m1Ptr = (float*) env->GetDirectBufferAddress(m1);
float* m2Ptr = (float*) env->GetDirectBufferAddress(m2);
float* retPtr = (float*) env->GetDirectBufferAddress(ret);
if(m1Ptr == NULL || m2Ptr == NULL || retPtr == NULL)
- return (jboolean) false;
+ return -1;
smatmult(m1Ptr, m2Ptr, retPtr, (int)m1rlen, (int)m1clen, (int)m2clen,
(int)numThreads);
- return (jboolean) true;
+ return static_cast<jlong>(computeNNZ<float>(retPtr, m1rlen * m2clen));
}
-JNIEXPORT jboolean JNICALL Java_org_apache_sysds_utils_NativeHelper_tsmm
+JNIEXPORT jlong JNICALL Java_org_apache_sysds_utils_NativeHelper_tsmm
(JNIEnv * env, jclass cls, jdoubleArray m1, jdoubleArray ret, jint m1rlen,
jint m1clen, jboolean leftTrans, jint numThreads) {
double* m1Ptr = GET_DOUBLE_ARRAY(env, m1, numThreads);
double* retPtr = GET_DOUBLE_ARRAY(env, ret, numThreads);
if(m1Ptr == NULL || retPtr == NULL)
- return (jboolean) false;
+ return -1;
Review comment:
jboolean is char, jlong is a signed integer (as is -1). Furthermore, I
did not get a warning from any of the compilers used (MSVC and GCC).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]