Baunsgaard commented on code in PR #1990: URL: https://github.com/apache/systemds/pull/1990#discussion_r1468854244
########## scripts/builtin/tSNE.dml: ########## @@ -119,7 +161,7 @@ return(matrix[double] P) while (mean(abs(Hdiff)) > tol & itr < 50) { P = exp(-D * beta) P = P * ZERODIAG - sum_Pi = rowSums(P) + sum_Pi = rowSums(P) = 1e-12 Review Comment: This line confuses me. What is the intention of `= 1e-12` here? ########## scripts/builtin/tSNE.dml: ########## @@ -141,4 +183,4 @@ return(matrix[double] P) P = P / sum(P) if(is_verbose) print("x2p finishing....") -} +} Review Comment: Please add a newline at the end of the script. ########## scripts/builtin/tSNE.dml: ########## @@ -72,6 +106,13 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity sumW = rowSums(W) g = Y * sumW - W %*% Y dY = momentum*dY - lr*g + + norm = sum(dY^2) + if(is_verbose & itr %%10 ==0){ Review Comment: Maybe add a print-interval parameter that controls how frequently progress is logged, with a default of 10, instead of hard-coding it. ########## src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinTSNETest.java: ########## @@ -403,4 +403,108 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr, rtplatform = platformOld; } } + + @Test + public void testTSNEEarlyStopping() throws IOException { + // Test setup guarantees early stopping. 
+ runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-3, 1, "FALSE", ExecType.CP); + } + + @SuppressWarnings("unused") + private void runTSNEEarlyStoppingTest( + Integer reduced_dims, + Integer perplexity, + Double lr, + Double momentum, + Integer max_iter, + Double tol, + Integer seed, + String is_verbose, + ExecType instType) throws IOException { + + ExecMode platformOld = setExecMode(instType); + try + { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[]{ + "-nvargs", "X=" + input("X"), "Y=" + output("Y"), + "reduced_dims=" + reduced_dims, + "perplexity=" + perplexity, + "lr=" + lr, + "momentum=" + momentum, + "max_iter=" + max_iter, + "tol= " + tol, + "seed=" + seed, + "is_verbose=" + is_verbose}; + + // The Input values are calculated using the following dml script: + // X = rand(rows=50, cols=2, min=0, max=5, seed=1) + + // Input + double[][] X = { + {2.271495063217468, 3.917376227330574}, + {1.8027277767886734, 0.5602638182708702}, + {1.8307117742445955, 3.080459273928752}, + {3.960945849944864, 1.836254625202115}, + {3.150073237716238, 1.9323395742833234}, + {2.6433200695475314, 4.1796071244359805}, + {3.837027652831461, 1.8984862827162574}, + {3.3620223450187448, 4.221502623882378}, + {2.3282372847390254, 4.602696981602351}, + {1.063050408038052, 3.049136203059148}, + {4.945108528303021, 3.290728762588105}, + {0.03997874419356229, 4.78972783775991}, + {3.219940877253892, 0.4546090824785526}, + {3.661862179707895, 4.9115252981693205}, + {2.006763020664273, 1.6504573252270927}, + {4.802896313025078, 3.7058196696317185}, + {4.989560263975035, 3.3590878579410233}, + {0.2881957805129115, 2.7235626348446864}, + {4.205473623958116, 0.7513651648333092}, + {1.5030599075316982, 1.9059965151083047}, + {4.111690819873698, 4.38922550887249}, + {3.55235293843559, 2.7707785045249262}, + {3.5421273466628604, 3.218473690489352}, + 
{2.0021344008348603, 3.293397607562143}, + {0.6236309437993054, 4.690911049840824}, + {4.28743111141226, 3.058259024138692}, + {1.351324262063277, 1.4910437726755477}, + {2.328053099817537, 2.844624510685577}, + {2.058835681566319, 3.1365678249943336}, + {3.758610361307626, 1.0596733909061373}, + {4.4615463190110205, 4.67202160391804}, + {1.44939799230235, 0.3342638523743646}, + {4.299621130384286, 3.781441439604645}, + {4.671573038039089, 1.1565494768485123}, + {0.8624668449657552, 1.9085522899983942}, + {0.34305466410947616, 0.6344221672215061}, + {4.837399879571096, 4.391970748711334}, + {4.280838563730712, 3.3498259946465705}, + {0.9926830544799081, 4.198090879512748}, + {0.2809217637487471, 2.7963040315556564}, + {0.17872992178431912, 3.565772551292108}, + {4.148793911769612, 1.0757141044759506}, + {2.0111513617190186, 2.7646430913923767}, + {0.5114578168532041, 1.3708650661139115}, + {0.38545762498678526, 0.21277125305527278}, + {2.356200617781426, 2.20790000896965}, + {3.665608219962555, 3.399666975542729}, + {1.7618442622801385, 4.675570729512945}, + {4.987236193552888, 0.41700477957766546}, + {0.21496074278985922, 3.5781179414157616} + }; + + + writeInputMatrixWithMTD("X", X, true); + + runTest(true, false, null, -1); Review Comment: We need to test the behavior. This test only verifies that the script runs, not that it works correctly. Perhaps you could inspect the printed output and check that the L1 norm decreases over the iterations? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: dev-unsubscr...@systemds.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org