Baunsgaard commented on code in PR #1990:
URL: https://github.com/apache/systemds/pull/1990#discussion_r1468854244


##########
scripts/builtin/tSNE.dml:
##########
@@ -119,7 +161,7 @@ return(matrix[double] P)
   while (mean(abs(Hdiff)) > tol & itr < 50) {
     P = exp(-D * beta)
     P = P * ZERODIAG
-    sum_Pi = rowSums(P)
+    sum_Pi = rowSums(P) = 1e-12

Review Comment:
   This line confuses me; what is the intention?



##########
scripts/builtin/tSNE.dml:
##########
@@ -141,4 +183,4 @@ return(matrix[double] P)
   P = P / sum(P)
   if(is_verbose)
     print("x2p finishing....")
-}
+}

Review Comment:
   Add a newline at the end of the script.



##########
scripts/builtin/tSNE.dml:
##########
@@ -72,6 +106,13 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 
2, Integer perplexity
     sumW = rowSums(W)
     g = Y * sumW - W %*% Y
     dY = momentum*dY - lr*g
+
+    norm = sum(dY^2)
+    if(is_verbose & itr %%10 ==0){

Review Comment:
   Maybe add a print-iteration variable that indicates how frequently the
logging happens, and default it to 10.



##########
src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java:
##########
@@ -403,4 +403,108 @@ private void runTSNETest(Integer reduced_dims, Integer 
perplexity, Double lr,
                        rtplatform = platformOld;
                }
        }
+
+       @Test
+       public void testTSNEEarlyStopping() throws IOException {
+               // Test setup guarantees early stopping.
+               runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-3, 1, 
"FALSE", ExecType.CP);
+       }
+
+       @SuppressWarnings("unused")
+       private void runTSNEEarlyStoppingTest(
+               Integer reduced_dims, 
+               Integer perplexity, 
+               Double lr,
+               Double momentum, 
+               Integer max_iter, 
+               Double tol, 
+               Integer seed, 
+               String is_verbose, 
+               ExecType instType) throws IOException {
+               
+               ExecMode platformOld = setExecMode(instType);
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{
+                               "-nvargs", "X=" + input("X"), "Y=" + 
output("Y"),
+                               "reduced_dims=" + reduced_dims,
+                               "perplexity=" + perplexity,
+                               "lr=" + lr,
+                               "momentum=" + momentum,
+                               "max_iter=" + max_iter,
+                               "tol= " + tol,
+                               "seed=" + seed,
+                               "is_verbose=" + is_verbose};
+
+                       // The Input values are calculated using the following 
dml script:
+                       // X = rand(rows=50, cols=2, min=0, max=5, seed=1)
+
+                       // Input
+        double[][] X = {
+                {2.271495063217468, 3.917376227330574},
+                {1.8027277767886734, 0.5602638182708702},
+                {1.8307117742445955, 3.080459273928752},
+                {3.960945849944864, 1.836254625202115},
+                {3.150073237716238, 1.9323395742833234},
+                {2.6433200695475314, 4.1796071244359805},
+                {3.837027652831461, 1.8984862827162574},
+                {3.3620223450187448, 4.221502623882378},
+                {2.3282372847390254, 4.602696981602351},
+                {1.063050408038052, 3.049136203059148},
+                {4.945108528303021, 3.290728762588105},
+                {0.03997874419356229, 4.78972783775991},
+                {3.219940877253892, 0.4546090824785526},
+                {3.661862179707895, 4.9115252981693205},
+                {2.006763020664273, 1.6504573252270927},
+                {4.802896313025078, 3.7058196696317185},
+                {4.989560263975035, 3.3590878579410233},
+                {0.2881957805129115, 2.7235626348446864},
+                {4.205473623958116, 0.7513651648333092},
+                {1.5030599075316982, 1.9059965151083047},
+                {4.111690819873698, 4.38922550887249},
+                {3.55235293843559, 2.7707785045249262},
+                {3.5421273466628604, 3.218473690489352},
+                {2.0021344008348603, 3.293397607562143},
+                {0.6236309437993054, 4.690911049840824},
+                {4.28743111141226, 3.058259024138692},
+                {1.351324262063277, 1.4910437726755477},
+                {2.328053099817537, 2.844624510685577},
+                {2.058835681566319, 3.1365678249943336},
+                {3.758610361307626, 1.0596733909061373},
+                {4.4615463190110205, 4.67202160391804},
+                {1.44939799230235, 0.3342638523743646},
+                {4.299621130384286, 3.781441439604645},
+                {4.671573038039089, 1.1565494768485123},
+                {0.8624668449657552, 1.9085522899983942},
+                {0.34305466410947616, 0.6344221672215061},
+                {4.837399879571096, 4.391970748711334},
+                {4.280838563730712, 3.3498259946465705},
+                {0.9926830544799081, 4.198090879512748},
+                {0.2809217637487471, 2.7963040315556564},
+                {0.17872992178431912, 3.565772551292108},
+                {4.148793911769612, 1.0757141044759506},
+                {2.0111513617190186, 2.7646430913923767},
+                {0.5114578168532041, 1.3708650661139115},
+                {0.38545762498678526, 0.21277125305527278},
+                {2.356200617781426, 2.20790000896965},
+                {3.665608219962555, 3.399666975542729},
+                {1.7618442622801385, 4.675570729512945},
+                {4.987236193552888, 0.41700477957766546},
+                {0.21496074278985922, 3.5781179414157616}
+        };
+
+
+                       writeInputMatrixWithMTD("X", X, true);
+
+                       runTest(true, false, null, -1);

Review Comment:
   We need to test for behavior.
   This test only verifies that the script runs, not that it works.
   Maybe you can look at the print output and check that the iterations give
a lower and lower l1 norm?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to