http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java b/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java deleted file mode 100644 index b76f8b8..0000000 --- a/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCInputFormatTest.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.api.java.io.jdbc; - -import java.io.IOException; -import java.sql.Connection; -import java.sql.DriverManager; -import java.sql.SQLException; -import java.sql.Statement; - -import org.junit.Assert; - -import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.api.java.tuple.Tuple5; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -public class JDBCInputFormatTest { - JDBCInputFormat jdbcInputFormat; - - static Connection conn; - - static final Object[][] dbData = { - {1001, ("Java for dummies"), ("Tan Ah Teck"), 11.11, 11}, - {1002, ("More Java for dummies"), ("Tan Ah Teck"), 22.22, 22}, - {1003, ("More Java for more dummies"), ("Mohammad Ali"), 33.33, 33}, - {1004, ("A Cup of Java"), ("Kumar"), 44.44, 44}, - {1005, ("A Teaspoon of Java"), ("Kevin Jones"), 55.55, 55}}; - - @BeforeClass - public static void setUpClass() { - try { - prepareDerbyDatabase(); - } catch (Exception e) { - Assert.fail(); - } - } - - private static void prepareDerbyDatabase() throws ClassNotFoundException, SQLException { - System.setProperty("derby.stream.error.field", "org.apache.flink.api.java.record.io.jdbc.DevNullLogStream.DEV_NULL"); - String dbURL = "jdbc:derby:memory:ebookshop;create=true"; - Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); - conn = DriverManager.getConnection(dbURL); - createTable(); - insertDataToSQLTable(); - conn.close(); - } - - private static void createTable() throws SQLException { - StringBuilder sqlQueryBuilder = new StringBuilder("CREATE TABLE books ("); - sqlQueryBuilder.append("id INT NOT NULL DEFAULT 0,"); - sqlQueryBuilder.append("title VARCHAR(50) DEFAULT NULL,"); - sqlQueryBuilder.append("author VARCHAR(50) DEFAULT NULL,"); - sqlQueryBuilder.append("price FLOAT DEFAULT NULL,"); - sqlQueryBuilder.append("qty INT DEFAULT NULL,"); - sqlQueryBuilder.append("PRIMARY KEY (id))"); - - Statement stat = conn.createStatement(); - stat.executeUpdate(sqlQueryBuilder.toString()); - stat.close(); - } - - private static void insertDataToSQLTable() throws SQLException { - 
StringBuilder sqlQueryBuilder = new StringBuilder("INSERT INTO books (id, title, author, price, qty) VALUES "); - sqlQueryBuilder.append("(1001, 'Java for dummies', 'Tan Ah Teck', 11.11, 11),"); - sqlQueryBuilder.append("(1002, 'More Java for dummies', 'Tan Ah Teck', 22.22, 22),"); - sqlQueryBuilder.append("(1003, 'More Java for more dummies', 'Mohammad Ali', 33.33, 33),"); - sqlQueryBuilder.append("(1004, 'A Cup of Java', 'Kumar', 44.44, 44),"); - sqlQueryBuilder.append("(1005, 'A Teaspoon of Java', 'Kevin Jones', 55.55, 55)"); - - Statement stat = conn.createStatement(); - stat.execute(sqlQueryBuilder.toString()); - stat.close(); - } - - @AfterClass - public static void tearDownClass() { - cleanUpDerbyDatabases(); - } - - private static void cleanUpDerbyDatabases() { - try { - String dbURL = "jdbc:derby:memory:ebookshop;create=true"; - Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); - - conn = DriverManager.getConnection(dbURL); - Statement stat = conn.createStatement(); - stat.executeUpdate("DROP TABLE books"); - stat.close(); - conn.close(); - } catch (Exception e) { - e.printStackTrace(); - Assert.fail(); - } - } - - @After - public void tearDown() { - jdbcInputFormat = null; - } - - @Test(expected = IllegalArgumentException.class) - public void testInvalidDriver() throws IOException { - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername("org.apache.derby.jdbc.idontexist") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("select * from books") - .finish(); - jdbcInputFormat.open(null); - } - - @Test(expected = IllegalArgumentException.class) - public void testInvalidURL() throws IOException { - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:der:iamanerror:mory:ebookshop") - .setQuery("select * from books") - .finish(); - jdbcInputFormat.open(null); - } - - @Test(expected = IllegalArgumentException.class) - public void testInvalidQuery() throws IOException { - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("iamnotsql") - .finish(); - jdbcInputFormat.open(null); - } - - @Test(expected = IllegalArgumentException.class) - public void testIncompleteConfiguration() throws IOException { - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setQuery("select * from books") - .finish(); - } - - @Test(expected = IOException.class) - public void testIncompatibleTuple() throws IOException { - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("select * from books") - .finish(); - jdbcInputFormat.open(null); - jdbcInputFormat.nextRecord(new Tuple2()); - } - - @Test - public void testJDBCInputFormat() throws IOException { - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("select * from books") - .finish(); - jdbcInputFormat.open(null); - Tuple5 tuple = new Tuple5(); - int recordCount = 0; - while (!jdbcInputFormat.reachedEnd()) { - jdbcInputFormat.nextRecord(tuple); - Assert.assertEquals("Field 0 should be int", Integer.class, tuple.getField(0).getClass()); - Assert.assertEquals("Field 1 should be String", String.class, tuple.getField(1).getClass()); - 
Assert.assertEquals("Field 2 should be String", String.class, tuple.getField(2).getClass()); - Assert.assertEquals("Field 3 should be float", Double.class, tuple.getField(3).getClass()); - Assert.assertEquals("Field 4 should be int", Integer.class, tuple.getField(4).getClass()); - - for (int x = 0; x < 5; x++) { - Assert.assertEquals(dbData[recordCount][x], tuple.getField(x)); - } - recordCount++; - } - Assert.assertEquals(5, recordCount); - } - -}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java b/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java deleted file mode 100644 index 7d004f9..0000000 --- a/flink-staging/flink-jdbc/src/test/java/org/apache/flink/api/java/io/jdbc/JDBCOutputFormatTest.java +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.api.java.io.jdbc; - -import java.io.IOException; -import java.sql.Connection; -import java.sql.DriverManager; -import java.sql.SQLException; -import java.sql.Statement; - -import org.junit.Assert; - -import org.apache.flink.api.java.tuple.Tuple5; -import org.junit.After; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -public class JDBCOutputFormatTest { - private JDBCInputFormat jdbcInputFormat; - private JDBCOutputFormat jdbcOutputFormat; - - private static Connection conn; - - static final Object[][] dbData = { - {1001, ("Java for dummies"), ("Tan Ah Teck"), 11.11, 11}, - {1002, ("More Java for dummies"), ("Tan Ah Teck"), 22.22, 22}, - {1003, ("More Java for more dummies"), ("Mohammad Ali"), 33.33, 33}, - {1004, ("A Cup of Java"), ("Kumar"), 44.44, 44}, - {1005, ("A Teaspoon of Java"), ("Kevin Jones"), 55.55, 55}}; - - @BeforeClass - public static void setUpClass() throws SQLException { - try { - System.setProperty("derby.stream.error.field", "org.apache.flink.api.java.record.io.jdbc.DevNullLogStream.DEV_NULL"); - prepareDerbyDatabase(); - } catch (ClassNotFoundException e) { - e.printStackTrace(); - Assert.fail(); - } - } - - private static void prepareDerbyDatabase() throws ClassNotFoundException, SQLException { - String dbURL = "jdbc:derby:memory:ebookshop;create=true"; - Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); - conn = DriverManager.getConnection(dbURL); - createTable("books"); - createTable("newbooks"); - insertDataToSQLTables(); - conn.close(); - } - - private static void createTable(String tableName) throws SQLException { - StringBuilder sqlQueryBuilder = new StringBuilder("CREATE TABLE "); - sqlQueryBuilder.append(tableName); - sqlQueryBuilder.append(" ("); - sqlQueryBuilder.append("id INT NOT NULL DEFAULT 0,"); - sqlQueryBuilder.append("title VARCHAR(50) DEFAULT NULL,"); - sqlQueryBuilder.append("author VARCHAR(50) DEFAULT NULL,"); - sqlQueryBuilder.append("price FLOAT DEFAULT NULL,"); - sqlQueryBuilder.append("qty INT DEFAULT NULL,"); - sqlQueryBuilder.append("PRIMARY KEY (id))"); 
- - Statement stat = conn.createStatement(); - stat.executeUpdate(sqlQueryBuilder.toString()); - stat.close(); - } - - private static void insertDataToSQLTables() throws SQLException { - StringBuilder sqlQueryBuilder = new StringBuilder("INSERT INTO books (id, title, author, price, qty) VALUES "); - sqlQueryBuilder.append("(1001, 'Java for dummies', 'Tan Ah Teck', 11.11, 11),"); - sqlQueryBuilder.append("(1002, 'More Java for dummies', 'Tan Ah Teck', 22.22, 22),"); - sqlQueryBuilder.append("(1003, 'More Java for more dummies', 'Mohammad Ali', 33.33, 33),"); - sqlQueryBuilder.append("(1004, 'A Cup of Java', 'Kumar', 44.44, 44),"); - sqlQueryBuilder.append("(1005, 'A Teaspoon of Java', 'Kevin Jones', 55.55, 55)"); - - Statement stat = conn.createStatement(); - stat.execute(sqlQueryBuilder.toString()); - stat.close(); - } - - @AfterClass - public static void tearDownClass() { - cleanUpDerbyDatabases(); - } - - private static void cleanUpDerbyDatabases() { - try { - String dbURL = "jdbc:derby:memory:ebookshop;create=true"; - Class.forName("org.apache.derby.jdbc.EmbeddedDriver"); - - conn = DriverManager.getConnection(dbURL); - Statement stat = conn.createStatement(); - stat.executeUpdate("DROP TABLE books"); - stat.executeUpdate("DROP TABLE newbooks"); - stat.close(); - conn.close(); - } catch (Exception e) { - e.printStackTrace(); - Assert.fail(); - } - } - - @After - public void tearDown() { - jdbcOutputFormat = null; - } - - @Test(expected = IllegalArgumentException.class) - public void testInvalidDriver() throws IOException { - jdbcOutputFormat = JDBCOutputFormat.buildJDBCOutputFormat() - .setDrivername("org.apache.derby.jdbc.idontexist") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("insert into books (id, title, author, price, qty) values (?,?,?,?,?)") - .finish(); - jdbcOutputFormat.open(0, 1); - } - - @Test(expected = IllegalArgumentException.class) - public void testInvalidURL() throws IOException { - jdbcOutputFormat = JDBCOutputFormat.buildJDBCOutputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:der:iamanerror:mory:ebookshop") - .setQuery("insert into books (id, title, author, price, qty) values (?,?,?,?,?)") - .finish(); - jdbcOutputFormat.open(0, 1); - } - - @Test(expected = IllegalArgumentException.class) - public void testInvalidQuery() throws IOException { - jdbcOutputFormat = JDBCOutputFormat.buildJDBCOutputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("iamnotsql") - .finish(); - jdbcOutputFormat.open(0, 1); - } - - @Test(expected = IllegalArgumentException.class) - public void testIncompleteConfiguration() throws IOException { - jdbcOutputFormat = JDBCOutputFormat.buildJDBCOutputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setQuery("insert into books (id, title, author, price, qty) values (?,?,?,?,?)") - .finish(); - } - - - @Test(expected = IllegalArgumentException.class) - public void testIncompatibleTypes() throws IOException { - jdbcOutputFormat = JDBCOutputFormat.buildJDBCOutputFormat() - .setDrivername("org.apache.derby.jdbc.EmbeddedDriver") - .setDBUrl("jdbc:derby:memory:ebookshop") - .setQuery("insert into books (id, title, author, price, qty) values (?,?,?,?,?)") - .finish(); - jdbcOutputFormat.open(0, 1); - - Tuple5 tuple5 = new Tuple5(); - tuple5.setField(4, 0); - tuple5.setField("hello", 1); - tuple5.setField("world", 2); - tuple5.setField(0.99, 3); - tuple5.setField("imthewrongtype", 4); - - 
jdbcOutputFormat.writeRecord(tuple5); - jdbcOutputFormat.close(); - } - - @Test - public void testJDBCOutputFormat() throws IOException { - String sourceTable = "books"; - String targetTable = "newbooks"; - String driverPath = "org.apache.derby.jdbc.EmbeddedDriver"; - String dbUrl = "jdbc:derby:memory:ebookshop"; - - jdbcOutputFormat = JDBCOutputFormat.buildJDBCOutputFormat() - .setDBUrl(dbUrl) - .setDrivername(driverPath) - .setQuery("insert into " + targetTable + " (id, title, author, price, qty) values (?,?,?,?,?)") - .finish(); - jdbcOutputFormat.open(0, 1); - - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername(driverPath) - .setDBUrl(dbUrl) - .setQuery("select * from " + sourceTable) - .finish(); - jdbcInputFormat.open(null); - - Tuple5 tuple = new Tuple5(); - while (!jdbcInputFormat.reachedEnd()) { - jdbcInputFormat.nextRecord(tuple); - jdbcOutputFormat.writeRecord(tuple); - } - - jdbcOutputFormat.close(); - jdbcInputFormat.close(); - - jdbcInputFormat = JDBCInputFormat.buildJDBCInputFormat() - .setDrivername(driverPath) - .setDBUrl(dbUrl) - .setQuery("select * from " + targetTable) - .finish(); - jdbcInputFormat.open(null); - - int recordCount = 0; - while (!jdbcInputFormat.reachedEnd()) { - jdbcInputFormat.nextRecord(tuple); - Assert.assertEquals("Field 0 should be int", Integer.class, tuple.getField(0).getClass()); - Assert.assertEquals("Field 1 should be String", String.class, tuple.getField(1).getClass()); - Assert.assertEquals("Field 2 should be String", String.class, tuple.getField(2).getClass()); - Assert.assertEquals("Field 3 should be double", Double.class, tuple.getField(3).getClass()); - Assert.assertEquals("Field 4 should be int", Integer.class, tuple.getField(4).getClass()); - - for (int x = 0; x < 5; x++) { - Assert.assertEquals(dbData[recordCount][x], tuple.getField(x)); - } - - recordCount++; - } - Assert.assertEquals(5, recordCount); - - jdbcInputFormat.close(); - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-jdbc/src/test/resources/log4j-test.properties ---------------------------------------------------------------------- diff --git a/flink-staging/flink-jdbc/src/test/resources/log4j-test.properties b/flink-staging/flink-jdbc/src/test/resources/log4j-test.properties deleted file mode 100644 index 2fb9345..0000000 --- a/flink-staging/flink-jdbc/src/test/resources/log4j-test.properties +++ /dev/null @@ -1,19 +0,0 @@ -################################################################################ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
-################################################################################ - -log4j.rootLogger=OFF \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-jdbc/src/test/resources/logback-test.xml ---------------------------------------------------------------------- diff --git a/flink-staging/flink-jdbc/src/test/resources/logback-test.xml b/flink-staging/flink-jdbc/src/test/resources/logback-test.xml deleted file mode 100644 index 8b3bb27..0000000 --- a/flink-staging/flink-jdbc/src/test/resources/logback-test.xml +++ /dev/null @@ -1,29 +0,0 @@ -<!-- - ~ Licensed to the Apache Software Foundation (ASF) under one - ~ or more contributor license agreements. See the NOTICE file - ~ distributed with this work for additional information - ~ regarding copyright ownership. The ASF licenses this file - ~ to you under the Apache License, Version 2.0 (the - ~ "License"); you may not use this file except in compliance - ~ with the License. You may obtain a copy of the License at - ~ - ~ http://www.apache.org/licenses/LICENSE-2.0 - ~ - ~ Unless required by applicable law or agreed to in writing, software - ~ distributed under the License is distributed on an "AS IS" BASIS, - ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - ~ See the License for the specific language governing permissions and - ~ limitations under the License. - --> - -<configuration> - <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> - <encoder> - <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{60} %X{sourceThread} - %msg%n</pattern> - </encoder> - </appender> - - <root level="WARN"> - <appender-ref ref="STDOUT"/> - </root> -</configuration> \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/README.md ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/README.md b/flink-staging/flink-ml/README.md deleted file mode 100644 index 5cabd7c..0000000 --- a/flink-staging/flink-ml/README.md +++ /dev/null @@ -1,22 +0,0 @@ -Flink-ML constitutes the machine learning library of Apache Flink. -Our vision is to make machine learning easily accessible to a wide audience and yet to achieve extraordinary performance. -For this purpose, Flink-ML is based on two pillars: - -Flink-ML contains implementations of popular ML algorithms which are highly optimized for Apache Flink. -These implementations allow scaling to data sizes which vastly exceed the memory of a single computer. -Flink-ML currently comprises the following algorithms: - -* Classification -** Soft-margin SVM -* Regression -** Multiple linear regression -* Recommendation -** Alternating least squares (ALS) - -Since most of the work in data analytics is related to pre- and post-processing of data, where performance is not crucial, Flink wants to offer a simple abstraction for these tasks. -Linear algebra, as the common ground of many ML algorithms, represents such a high-level abstraction. -Therefore, Flink will support the Mahout DSL as an execution engine and provide tools to neatly integrate the optimized algorithms into a linear algebra program. - -Flink-ML has been started only recently. -As part of Apache Flink, it heavily relies on the active work and contributions of its community and others. -Thus, if you want to add a new algorithm to the library, then find out [how to contribute](http://flink.apache.org/how-to-contribute.html) and open a pull request!
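A minimal usage sketch for the algorithms listed in this README, based on the MLUtils and SVM sources deleted later in this commit; the input path and the parameter values are placeholders:

    import org.apache.flink.api.scala._
    import org.apache.flink.ml.MLUtils
    import org.apache.flink.ml.classification.SVM
    import org.apache.flink.ml.common.LabeledVector

    val env = ExecutionEnvironment.getExecutionEnvironment

    // Read training data in libSVM/SVMLight format; the path is a placeholder.
    val training: DataSet[LabeledVector] = MLUtils.readLibSVM(env, "/path/to/training/data")

    // Train the soft-margin SVM from the algorithm list above.
    val svm = SVM()
      .setBlocks(env.getParallelism)
      .setIterations(10)

    svm.fit(training)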
\ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/pom.xml ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/pom.xml b/flink-staging/flink-ml/pom.xml deleted file mode 100644 index 80c464c..0000000 --- a/flink-staging/flink-ml/pom.xml +++ /dev/null @@ -1,162 +0,0 @@ -<?xml version="1.0" encoding="UTF-8"?> -<!-- - ~ Licensed to the Apache Software Foundation (ASF) under one - ~ or more contributor license agreements. See the NOTICE file - ~ distributed with this work for additional information - ~ regarding copyright ownership. The ASF licenses this file - ~ to you under the Apache License, Version 2.0 (the - ~ "License"); you may not use this file except in compliance - ~ with the License. You may obtain a copy of the License at - ~ - ~ http://www.apache.org/licenses/LICENSE-2.0 - ~ - ~ Unless required by applicable law or agreed to in writing, software - ~ distributed under the License is distributed on an "AS IS" BASIS, - ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - ~ See the License for the specific language governing permissions and - ~ limitations under the License. - --> -<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> - - <modelVersion>4.0.0</modelVersion> - - <parent> - <groupId>org.apache.flink</groupId> - <artifactId>flink-staging</artifactId> - <version>1.0-SNAPSHOT</version> - <relativePath>..</relativePath> - </parent> - - <artifactId>flink-ml</artifactId> - <name>flink-ml</name> - - <packaging>jar</packaging> - - <dependencies> - <dependency> - <groupId>org.apache.flink</groupId> - <artifactId>flink-scala</artifactId> - <version>${project.version}</version> - </dependency> - - <dependency> - <groupId>org.scalanlp</groupId> - <artifactId>breeze_${scala.binary.version}</artifactId> - <version>0.11.2</version> - </dependency> - - <dependency> - <groupId>org.apache.flink</groupId> - <artifactId>flink-clients</artifactId> - <version>${project.version}</version> - <scope>test</scope> - </dependency> - - <dependency> - <groupId>org.apache.flink</groupId> - <artifactId>flink-clients</artifactId> - <version>${project.version}</version> - <type>test-jar</type> - <scope>test</scope> - </dependency> - - <dependency> - <groupId>org.apache.flink</groupId> - <artifactId>flink-core</artifactId> - <version>${project.version}</version> - <type>test-jar</type> - <scope>test</scope> - </dependency> - - <dependency> - <groupId>org.apache.flink</groupId> - <artifactId>flink-test-utils</artifactId> - <version>${project.version}</version> - <scope>test</scope> - </dependency> - </dependencies> - - <build> - <plugins> - <plugin> - <groupId>org.scala-tools</groupId> - <artifactId>maven-scala-plugin</artifactId> - <version>2.15.2</version> - <executions> - <execution> - <goals> - <goal>compile</goal> - <goal>testCompile</goal> - </goals> - </execution> - </executions> - <configuration> - <sourceDir>src/main/scala</sourceDir> - <testSourceDir>src/test/scala</testSourceDir> - <jvmArgs> - <jvmArg>-Xms64m</jvmArg> - <jvmArg>-Xmx1024m</jvmArg> - </jvmArgs> - </configuration> - </plugin> - - <plugin> - <groupId>org.scalatest</groupId> - <artifactId>scalatest-maven-plugin</artifactId> - <version>1.0</version> - <configuration> - <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory> - 
<stdout>W</stdout> <!-- Skip coloring output --> - </configuration> - <executions> - <execution> - <id>scala-test</id> - <goals> - <goal>test</goal> - </goals> - <configuration> - <suffixes>(?<!(IT|Integration))(Test|Suite|Case)</suffixes> - <argLine>-Xms256m -Xmx800m -Dlog4j.configuration=${log4j.configuration} -Dlog.dir=${log.dir} -Dmvn.forkNumber=1 -XX:-UseGCOverheadLimit</argLine> - </configuration> - </execution> - <execution> - <id>integration-test</id> - <phase>integration-test</phase> - <goals> - <goal>test</goal> - </goals> - <configuration> - <suffixes>(IT|Integration)(Test|Suite|Case)</suffixes> - <argLine>-Xms256m -Xmx800m -Dlog4j.configuration=${log4j.configuration} -Dlog.dir=${log.dir} -Dmvn.forkNumber=1 -XX:-UseGCOverheadLimit</argLine> - </configuration> - </execution> - </executions> - </plugin> - - <plugin> - <groupId>org.scalastyle</groupId> - <artifactId>scalastyle-maven-plugin</artifactId> - <version>0.5.0</version> - <executions> - <execution> - <goals> - <goal>check</goal> - </goals> - </execution> - </executions> - <configuration> - <verbose>false</verbose> - <failOnViolation>true</failOnViolation> - <includeTestSourceDirectory>true</includeTestSourceDirectory> - <failOnWarning>false</failOnWarning> - <sourceDirectory>${basedir}/src/main/scala</sourceDirectory> - <testSourceDirectory>${basedir}/src/test/scala</testSourceDirectory> - <configLocation>${project.basedir}/../../tools/maven/scalastyle-config.xml</configLocation> - <outputFile>${project.basedir}/scalastyle-output.xml</outputFile> - <outputEncoding>UTF-8</outputEncoding> - </configuration> - </plugin> - </plugins> - </build> -</project> http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala deleted file mode 100644 index 804ab5f..0000000 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.flink.ml - -import org.apache.flink.api.common.functions.RichMapFunction -import org.apache.flink.api.java.operators.DataSink -import org.apache.flink.api.scala._ -import org.apache.flink.configuration.Configuration -import org.apache.flink.ml.common.LabeledVector -import org.apache.flink.ml.math.SparseVector - -/** Convenience functions for machine learning tasks - * - * This object contains convenience functions for machine learning tasks: - * - * - readLibSVM: - * Reads a libSVM/SVMLight input file and returns a data set of [[LabeledVector]]. - * The file format is specified [http://svmlight.joachims.org/ here]. - * - * - writeLibSVM: - * Writes a data set of [[LabeledVector]] in libSVM/SVMLight format to disk. The file format - * is specified [http://svmlight.joachims.org/ here]. - */ -object MLUtils { - - val DIMENSION = "dimension" - - /** Reads a file in libSVM/SVMLight format and converts the data into a data set of - * [[LabeledVector]]. The dimension of the [[LabeledVector]] is determined automatically. - * - * Since the libSVM/SVMLight format stores a vector in its sparse form, the [[LabeledVector]] - * will also be instantiated with a [[SparseVector]]. - * - * @param env [[ExecutionEnvironment]] used to read the input file - * @param filePath Path to the input file - * @return [[DataSet]] of [[LabeledVector]] containing the information of the libSVM/SVMLight - * file - */ - def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = { - val labelCOODS = env.readTextFile(filePath).flatMap { - line => - // remove all comments which start with a '#' - val commentFreeLine = line.takeWhile(_ != '#').trim - - if(commentFreeLine.nonEmpty) { - val splits = commentFreeLine.split(' ') - val label = splits.head.toDouble - val sparseFeatures = splits.tail - val coos = sparseFeatures.map { - str => - val pair = str.split(':') - require(pair.length == 2, "Each feature entry has to have the form <feature>:<value>") - - // libSVM index is 1-based, but we expect it to be 0-based - val index = pair(0).toInt - 1 - val value = pair(1).toDouble - - (index, value) - } - - Some((label, coos)) - } else { - None - } - } - - // Calculate maximum dimension of vectors - val dimensionDS = labelCOODS.map { - labelCOO => - labelCOO._2.map( _._1 + 1 ).max - }.reduce(scala.math.max(_, _)) - - labelCOODS.map{ new RichMapFunction[(Double, Array[(Int, Double)]), LabeledVector] { - var dimension = 0 - - override def open(configuration: Configuration): Unit = { - dimension = getRuntimeContext.getBroadcastVariable(DIMENSION).get(0) - } - - override def map(value: (Double, Array[(Int, Double)])): LabeledVector = { - new LabeledVector(value._1, SparseVector.fromCOO(dimension, value._2)) - } - }}.withBroadcastSet(dimensionDS, DIMENSION) - } - - /** Writes a [[DataSet]] of [[LabeledVector]] to a file using the libSVM/SVMLight format. - * - * @param filePath Path to output file - * @param labeledVectors [[DataSet]] of [[LabeledVector]] to write to disk - * @return [[DataSink]] which writes the given [[DataSet]] of [[LabeledVector]] to disk - */ - def writeLibSVM(filePath: String, labeledVectors: DataSet[LabeledVector]): DataSink[String] = { - val stringRepresentation = labeledVectors.map{ - labeledVector => - val vectorStr = labeledVector.vector. - // remove zero entries - filter( _._2 != 0). - map{case (idx, value) => (idx + 1) + ":" + value}.
- mkString(" ") - - labeledVector.label + " " + vectorStr - } - - stringRepresentation.writeAsText(filePath) - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala deleted file mode 100644 index 4a780e9..0000000 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala +++ /dev/null @@ -1,550 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.classification - -import org.apache.flink.api.common.functions.RichMapFunction -import org.apache.flink.api.scala._ -import org.apache.flink.configuration.Configuration -import org.apache.flink.ml.common.FlinkMLTools.ModuloKeyPartitioner -import org.apache.flink.ml.common._ -import org.apache.flink.ml.math.Breeze._ -import org.apache.flink.ml.math.{DenseVector, Vector} -import org.apache.flink.ml.pipeline.{FitOperation, PredictOperation, Predictor} - -import scala.collection.mutable.ArrayBuffer -import scala.util.Random - -import breeze.linalg.{DenseVector => BreezeDenseVector, Vector => BreezeVector} - -/** Implements a soft-margin SVM using the communication-efficient distributed dual coordinate - * ascent algorithm (CoCoA) with hinge-loss function. - * - * It can be used for binary classification problems, with the labels set as +1.0 to indiciate a - * positive example and -1.0 to indicate a negative example. - * - * The algorithm solves the following minimization problem: - * - * `min_{w in bbb"R"^d} lambda/2 ||w||^2 + 1/n sum_(i=1)^n l_{i}(w^Tx_i)` - * - * with `w` being the weight vector, `lambda` being the regularization constant, - * `x_{i} in bbb"R"^d` being the data points and `l_{i}` being the convex loss functions, which - * can also depend on the labels `y_{i} in bbb"R"`. - * In the current implementation the regularizer is the 2-norm and the loss functions are the - * hinge-loss functions: - * - * `l_{i} = max(0, 1 - y_{i} * w^Tx_i` - * - * With these choices, the problem definition is equivalent to a SVM with soft-margin. - * Thus, the algorithm allows us to train a SVM with soft-margin. - * - * The minimization problem is solved by applying stochastic dual coordinate ascent (SDCA). - * In order to make the algorithm efficient in a distributed setting, the CoCoA algorithm - * calculates several iterations of SDCA locally on a data block before merging the local - * updates into a valid global state. 
- * This state is redistributed to the different data partitions where the next round of local - * SDCA iterations is then executed. - * The number of outer iterations and local SDCA iterations control the overall network costs, - * because network communication is only required for each outer iteration. - * The local SDCA iterations are embarrassingly parallel once the individual data partitions have - * been distributed across the cluster. - * - * Further details of the algorithm can be found [[http://arxiv.org/abs/1409.1458 here]]. - * - * @example - * {{{ - * val trainingDS: DataSet[LabeledVector] = env.readLibSVM(pathToTrainingFile) - * - * val svm = SVM() - * .setBlocks(10) - * - * svm.fit(trainingDS) - * - * val testingDS: DataSet[Vector] = env.readLibSVM(pathToTestingFile) - * .map(lv => lv.vector) - * - * val predictionDS: DataSet[(Vector, Double)] = svm.predict(testingDS) - * }}} - * - * =Parameters= - * - * - [[org.apache.flink.ml.classification.SVM.Blocks]]: - * Sets the number of blocks into which the input data will be split. On each block the local - * stochastic dual coordinate ascent method is executed. This number should be set at least to - * the degree of parallelism. If no value is specified, then the parallelism of the input - * [[DataSet]] is used as the number of blocks. (Default value: '''None''') - * - * - [[org.apache.flink.ml.classification.SVM.Iterations]]: - * Defines the maximum number of iterations of the outer loop method. In other words, it defines - * how often the SDCA method is applied to the blocked data. After each iteration, the locally - * computed weight vector updates have to be reduced to update the global weight vector value. - * The new weight vector is broadcast to all SDCA tasks at the beginning of each iteration. - * (Default value: '''10''') - * - * - [[org.apache.flink.ml.classification.SVM.LocalIterations]]: - * Defines the maximum number of SDCA iterations. In other words, it defines how many data points - * are drawn from each local data block to calculate the stochastic dual coordinate ascent. - * (Default value: '''10''') - * - * - [[org.apache.flink.ml.classification.SVM.Regularization]]: - * Defines the regularization constant of the SVM algorithm. The higher the value, the smaller - * the 2-norm of the weight vector will be. In case of an SVM with hinge loss this means that the - * SVM margin will be wider even though it might contain some false classifications. - * (Default value: '''1.0''') - * - * - [[org.apache.flink.ml.classification.SVM.Stepsize]]: - * Defines the initial step size for the updates of the weight vector. The larger the step size - * is, the larger the contribution of the weight vector updates to the next weight vector - * value will be. The effective scaling of the updates is `stepsize/blocks`. This value has to be - * tuned in case the algorithm becomes unstable. (Default value: '''1.0''') - * - * - [[org.apache.flink.ml.classification.SVM.Seed]]: - * Defines the seed to initialize the random number generator. The seed directly controls which - * data points are chosen for the SDCA method. (Default value: '''Random.nextLong()''') - * - * - [[org.apache.flink.ml.classification.SVM.ThresholdValue]]: - * Defines the limiting value for the decision function above which examples are labeled as - * positive (+1.0). Examples with a decision function value below this value are classified as - * negative (-1.0).
In order to get the raw decision function values you need to indicate this by - * setting [[org.apache.flink.ml.classification.SVM.OutputDecisionFunction]]. - * (Default value: '''0.0''') - * - * - [[org.apache.flink.ml.classification.SVM.OutputDecisionFunction]]: - * Determines whether the predict and evaluate functions of the SVM should return the distance - * to the separating hyperplane, or binary class labels. Setting this to true will return the raw - * distance to the hyperplane for each example. Setting it to false will return the binary - * class label (+1.0, -1.0). (Default value: '''false''') - */ -class SVM extends Predictor[SVM] { - - import SVM._ - - /** Stores the learned weight vector after the fit operation */ - var weightsOption: Option[DataSet[DenseVector]] = None - - /** Sets the number of data blocks/partitions - * - * @param blocks Number of data blocks - * @return itself - */ - def setBlocks(blocks: Int): SVM = { - parameters.add(Blocks, blocks) - this - } - - /** Sets the number of outer iterations - * - * @param iterations Number of outer iterations - * @return itself - */ - def setIterations(iterations: Int): SVM = { - parameters.add(Iterations, iterations) - this - } - - /** Sets the number of local SDCA iterations - * - * @param localIterations Number of local SDCA iterations - * @return itself - */ - def setLocalIterations(localIterations: Int): SVM = { - parameters.add(LocalIterations, localIterations) - this - } - - /** Sets the regularization constant - * - * @param regularization Regularization constant - * @return itself - */ - def setRegularization(regularization: Double): SVM = { - parameters.add(Regularization, regularization) - this - } - - /** Sets the stepsize for the weight vector updates - * - * @param stepsize Step size for the weight vector updates - * @return itself - */ - def setStepsize(stepsize: Double): SVM = { - parameters.add(Stepsize, stepsize) - this - } - - /** Sets the seed value for the random number generator - * - * @param seed Seed value for the random number generator - * @return itself - */ - def setSeed(seed: Long): SVM = { - parameters.add(Seed, seed) - this - } - - /** Sets the threshold above which elements are classified as positive. - * - * The [[predict]] and [[evaluate]] functions will return +1.0 for items with a decision - * function value above this threshold, and -1.0 for items below it. - * @param threshold Threshold value - * @return itself - */ - def setThreshold(threshold: Double): SVM = { - parameters.add(ThresholdValue, threshold) - this - } - - /** Sets whether the predictions should return the raw decision function value or the - * thresholded binary value. - * - * When setting this to true, predict and evaluate return the raw decision value, which is - * the distance from the separating hyperplane. - * When setting this to false, they return thresholded (+1.0, -1.0) values. - * - * @param outputDecisionFunction When set to true, [[predict]] and [[evaluate]] return the raw - * decision function values. When set to false, they return the - * thresholded binary values (+1.0, -1.0). - */ - def setOutputDecisionFunction(outputDecisionFunction: Boolean): SVM = { - parameters.add(OutputDecisionFunction, outputDecisionFunction) - this - } -} - -/** Companion object of SVM. Contains convenience functions and the parameter type definitions - * of the algorithm.
- */ -object SVM { - - val WEIGHT_VECTOR = "weightVector" - - // ========================================== Parameters ========================================= - - case object Blocks extends Parameter[Int] { - val defaultValue: Option[Int] = None - } - - case object Iterations extends Parameter[Int] { - val defaultValue = Some(10) - } - - case object LocalIterations extends Parameter[Int] { - val defaultValue = Some(10) - } - - case object Regularization extends Parameter[Double] { - val defaultValue = Some(1.0) - } - - case object Stepsize extends Parameter[Double] { - val defaultValue = Some(1.0) - } - - case object Seed extends Parameter[Long] { - val defaultValue = Some(Random.nextLong()) - } - - case object ThresholdValue extends Parameter[Double] { - val defaultValue = Some(0.0) - } - - case object OutputDecisionFunction extends Parameter[Boolean] { - val defaultValue = Some(false) - } - - // ========================================== Factory methods ==================================== - - def apply(): SVM = { - new SVM() - } - - // ========================================== Operations ========================================= - - /** Provides the operation that makes the predictions for individual examples. - * - * @tparam T Subtype of [[Vector]] - * @return A PredictOperation, through which it is possible to predict a value, given a - * feature vector - */ - implicit def predictVectors[T <: Vector] = { - new PredictOperation[SVM, DenseVector, T, Double](){ - - var thresholdValue: Double = _ - var outputDecisionFunction: Boolean = _ - - override def getModel(self: SVM, predictParameters: ParameterMap): DataSet[DenseVector] = { - thresholdValue = predictParameters(ThresholdValue) - outputDecisionFunction = predictParameters(OutputDecisionFunction) - self.weightsOption match { - case Some(model) => model - case None => { - throw new RuntimeException("The SVM model has not been trained. Call fit first " + - "before calling the predict operation.") - } - } - } - - override def predict(value: T, model: DenseVector): Double = { - val rawValue = value.asBreeze dot model.asBreeze - - if (outputDecisionFunction) { - rawValue - } else { - if (rawValue > thresholdValue) 1.0 else -1.0 - } - } - } - } - - /** [[FitOperation]] which trains an SVM with soft-margin based on the given training data set.
- * - */ - implicit val fitSVM = { - new FitOperation[SVM, LabeledVector] { - override def fit( - instance: SVM, - fitParameters: ParameterMap, - input: DataSet[LabeledVector]) - : Unit = { - val resultingParameters = instance.parameters ++ fitParameters - - // Check if the number of blocks/partitions has been specified - val blocks = resultingParameters.get(Blocks) match { - case Some(value) => value - case None => input.getParallelism - } - - val scaling = resultingParameters(Stepsize)/blocks - val iterations = resultingParameters(Iterations) - val localIterations = resultingParameters(LocalIterations) - val regularization = resultingParameters(Regularization) - val seed = resultingParameters(Seed) - - // Obtain DataSet with the dimension of the data points - val dimension = input.map{_.vector.size}.reduce{ - (a, b) => { - require(a == b, "Dimensions of feature vectors have to be equal.") - a - } - } - - val initialWeights = createInitialWeights(dimension) - - // Count the number of vectors, but keep the value in a DataSet to broadcast it later - // TODO: Once efficient count and intermediate result partitions are implemented, use count - val numberVectors = input map { x => 1 } reduce { _ + _ } - - // Group the input data into blocks in round robin fashion - val blockedInputNumberElements = FlinkMLTools.block( - input, - blocks, - Some(ModuloKeyPartitioner)). - cross(numberVectors). - map { x => x } - - val resultingWeights = initialWeights.iterate(iterations) { - weights => { - // compute the local SDCA to obtain the weight vector updates - val deltaWs = localDualMethod( - weights, - blockedInputNumberElements, - localIterations, - regularization, - scaling, - seed - ) - - // scale the weight vectors - val weightedDeltaWs = deltaWs map { - deltaW => { - deltaW :*= scaling - } - } - - // calculate the new weight vector by adding the weight vector updates to the weight - // vector value - weights.union(weightedDeltaWs).reduce { _ + _ } - } - } - - // Store the learned weight vector in the given instance - instance.weightsOption = Some(resultingWeights.map(_.fromBreeze[DenseVector])) - } - } - } - - /** Creates a zero vector of length dimension - * - * @param dimension [[DataSet]] containing the dimension of the initial weight vector - * @return Zero vector of length dimension - */ - private def createInitialWeights(dimension: DataSet[Int]): DataSet[BreezeDenseVector[Double]] = { - dimension.map { - d => BreezeDenseVector.zeros[Double](d) - } - } - - /** Computes the local SDCA on the individual data blocks/partitions - * - * @param w Current weight vector - * @param blockedInputNumberElements Blocked/Partitioned input data - * @param localIterations Number of local SDCA iterations - * @param regularization Regularization constant - * @param scaling Scaling value for new weight vector updates - * @param seed Random number generator seed - * @return [[DataSet]] of weight vector updates. The weight vector updates are double arrays - */ - private def localDualMethod( - w: DataSet[BreezeDenseVector[Double]], - blockedInputNumberElements: DataSet[(Block[LabeledVector], Int)], - localIterations: Int, - regularization: Double, - scaling: Double, - seed: Long) - : DataSet[BreezeDenseVector[Double]] = { - /* - Rich mapper that calculates the local SDCA for each data block. We use a RichMapFunction here, - because we broadcast the current value of the weight vector to all mappers.
- */ - val localSDCA = new RichMapFunction[(Block[LabeledVector], Int), BreezeDenseVector[Double]] { - var originalW: BreezeDenseVector[Double] = _ - // we keep the alphas across the outer loop iterations - val alphasArray = ArrayBuffer[BreezeDenseVector[Double]]() - // there might be several data blocks in one Flink partition, therefore store mapping - val idMapping = scala.collection.mutable.HashMap[Int, Int]() - var counter = 0 - - var r: Random = _ - - override def open(parameters: Configuration): Unit = { - originalW = getRuntimeContext.getBroadcastVariable(WEIGHT_VECTOR).get(0) - - if(r == null){ - r = new Random(seed ^ getRuntimeContext.getIndexOfThisSubtask) - } - } - - override def map(blockNumberElements: (Block[LabeledVector], Int)) - : BreezeDenseVector[Double] = { - val (block, numberElements) = blockNumberElements - - // check if we already processed a data block with the corresponding block index - val localIndex = idMapping.get(block.index) match { - case Some(idx) => idx - case None => - idMapping += (block.index -> counter) - counter += 1 - - alphasArray += BreezeDenseVector.zeros[Double](block.values.length) - - counter - 1 - } - - // create temporary alpha array for the local SDCA iterations - val tempAlphas = alphasArray(localIndex).copy - - val numLocalDatapoints = tempAlphas.length - val deltaAlphas = BreezeDenseVector.zeros[Double](numLocalDatapoints) - - val w = originalW.copy - - val deltaW = BreezeDenseVector.zeros[Double](originalW.length) - - for(i <- 1 to localIterations) { - // pick random data point for SDCA - val idx = r.nextInt(numLocalDatapoints) - - val LabeledVector(label, vector) = block.values(idx) - val alpha = tempAlphas(idx) - - // maximize the dual problem and retrieve alpha and weight vector updates - val (deltaAlpha, deltaWUpdate) = maximize( - vector.asBreeze, - label, - regularization, - alpha, - w, - numberElements) - - // update alpha values - tempAlphas(idx) += deltaAlpha - deltaAlphas(idx) += deltaAlpha - - // deltaWUpdate is already scaled with 1/lambda/n - w += deltaWUpdate - deltaW += deltaWUpdate - } - - // update local alpha values - alphasArray(localIndex) += deltaAlphas * scaling - - deltaW - } - } - - blockedInputNumberElements.map(localSDCA).withBroadcastSet(w, WEIGHT_VECTOR) - } - - /** Maximizes the dual problem using hinge loss functions. It returns the alpha and weight - * vector updates. 
- * - * @param x Selected data point - * @param y Label of selected data point - * @param regularization Regularization constant - * @param alpha Alpha value of selected data point - * @param w Current weight vector value - * @param numberElements Number of elements in the training data set - * @return Alpha and weight vector updates - */ - private def maximize( - x: BreezeVector[Double], - y: Double, regularization: Double, - alpha: Double, - w: BreezeVector[Double], - numberElements: Int) - : (Double, BreezeVector[Double]) = { - // compute hinge loss gradient - val dotProduct = x dot w - val grad = (y * dotProduct - 1.0) * (regularization * numberElements) - - // compute projected gradient - val proj_grad = if(alpha <= 0.0){ - scala.math.min(grad, 0) - } else if(alpha >= 1.0) { - scala.math.max(grad, 0) - } else { - grad - } - - if(scala.math.abs(grad) != 0.0){ - val qii = x dot x - val newAlpha = if(qii != 0.0){ - scala.math.min(scala.math.max(alpha - (grad / qii), 0.0), 1.0) - } else { - 1.0 - } - - val deltaW = x * y * (newAlpha - alpha) / (regularization * numberElements) - - (newAlpha - alpha, deltaW) - } else { - (0.0, BreezeVector.zeros(w.length)) - } - } - -} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala deleted file mode 100644 index 1af77ea..0000000 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/Block.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common - -/** Base class for blocks of elements. - * - * TODO: Replace Vector type by Array type once Flink supports generic arrays - * - * @param index Index of the block - * @param values Elements contained in the block - * @tparam T Type of the contained elements - */ -case class Block[T](index: Int, values: Vector[T]) {} http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala deleted file mode 100644 index 553ec00..0000000 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala +++ /dev/null @@ -1,423 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements.
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.ml.common - -import org.apache.flink.api.common.functions.Partitioner -import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode -import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat} -import org.apache.flink.api.scala._ -import org.apache.flink.core.fs.FileSystem.WriteMode -import org.apache.flink.core.fs.Path - -import scala.reflect.ClassTag - -/** FlinkMLTools contains a set of convenience functions for Flink's machine learning library: - * - * - persist: - * Takes up to 5 [[DataSet]]s and file paths. Each [[DataSet]] is written to the specified - * path and subsequently re-read from disk. This method can be used to effectively split the - * execution graph at the given [[DataSet]]. Writing it to disk triggers its materialization - * and specifying it as a source prevents it from being re-executed. - * - * - block: - * Takes a DataSet of elements T and groups them into n blocks. - * - */ -object FlinkMLTools { - - /** Registers the different FlinkML related types for Kryo serialization - * - * @param env [[ExecutionEnvironment]] in which the types are registered - */ - def registerFlinkMLTypes(env: ExecutionEnvironment): Unit = { - - // Vector types - env.registerType(classOf[org.apache.flink.ml.math.DenseVector]) - env.registerType(classOf[org.apache.flink.ml.math.SparseVector]) - - // Matrix types - env.registerType(classOf[org.apache.flink.ml.math.DenseMatrix]) - env.registerType(classOf[org.apache.flink.ml.math.SparseMatrix]) - - // Breeze Vector types - env.registerType(classOf[breeze.linalg.DenseVector[_]]) - env.registerType(classOf[breeze.linalg.SparseVector[_]]) - - // Breeze specialized types - env.registerType(breeze.linalg.DenseVector.zeros[Double](0).getClass) - env.registerType(breeze.linalg.SparseVector.zeros[Double](0).getClass) - - // Breeze Matrix types - env.registerType(classOf[breeze.linalg.DenseMatrix[Double]]) - env.registerType(classOf[breeze.linalg.CSCMatrix[Double]]) - - // Breeze specialized types - env.registerType(breeze.linalg.DenseMatrix.zeros[Double](0, 0).getClass) - env.registerType(breeze.linalg.CSCMatrix.zeros[Double](0, 0).getClass) - } - - /** Writes a [[DataSet]] to the specified path and returns it as a DataSource for subsequent - * operations.
- * - * @param dataset [[DataSet]] to write to disk - * @param path File path to write dataset to - * @tparam T Type of the [[DataSet]] elements - * @return [[DataSet]] reading the just written file - */ - def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = { - val env = dataset.getExecutionEnvironment - val outputFormat = new TypeSerializerOutputFormat[T] - - val filePath = new Path(path) - - outputFormat.setOutputFilePath(filePath) - outputFormat.setWriteMode(WriteMode.OVERWRITE) - - dataset.output(outputFormat) - env.execute("FlinkTools persist") - - val inputFormat = new TypeSerializerInputFormat[T](dataset.getType) - inputFormat.setFilePath(filePath) - - env.createInput(inputFormat) - } - - /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for - * subsequent operations. - * - * @param ds1 First [[DataSet]] to write to disk - * @param ds2 Second [[DataSet]] to write to disk - * @param path1 Path for ds1 - * @param path2 Path for ds2 - * @tparam A Type of the first [[DataSet]]'s elements - * @tparam B Type of the second [[DataSet]]'s elements - * @return Tuple of [[DataSet]]s reading the just written files - */ - def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation](ds1: DataSet[A], ds2: - DataSet[B], path1: String, path2: String):(DataSet[A], DataSet[B]) = { - val env = ds1.getExecutionEnvironment - - val f1 = new Path(path1) - - val of1 = new TypeSerializerOutputFormat[A] - of1.setOutputFilePath(f1) - of1.setWriteMode(WriteMode.OVERWRITE) - - ds1.output(of1) - - val f2 = new Path(path2) - - val of2 = new TypeSerializerOutputFormat[B] - of2.setOutputFilePath(f2) - of2.setWriteMode(WriteMode.OVERWRITE) - - ds2.output(of2) - - env.execute("FlinkTools persist") - - val if1 = new TypeSerializerInputFormat[A](ds1.getType) - if1.setFilePath(f1) - - val if2 = new TypeSerializerInputFormat[B](ds2.getType) - if2.setFilePath(f2) - - (env.createInput(if1), env.createInput(if2)) - } - - /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for - * subsequent operations. 
- * - * @param ds1 First [[DataSet]] to write to disk - * @param ds2 Second [[DataSet]] to write to disk - * @param ds3 Third [[DataSet]] to write to disk - * @param path1 Path for ds1 - * @param path2 Path for ds2 - * @param path3 Path for ds3 - * @tparam A Type of first [[DataSet]]'s elements - * @tparam B Type of second [[DataSet]]'s elements - * @tparam C Type of third [[DataSet]]'s elements - * @return Tuple of [[DataSet]]s reading the just written files - */ - def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation, - C: ClassTag: TypeInformation](ds1: DataSet[A], ds2: DataSet[B], ds3: DataSet[C], path1: - String, path2: String, path3: String): (DataSet[A], DataSet[B], DataSet[C]) = { - val env = ds1.getExecutionEnvironment - - val f1 = new Path(path1) - - val of1 = new TypeSerializerOutputFormat[A] - of1.setOutputFilePath(f1) - of1.setWriteMode(WriteMode.OVERWRITE) - - ds1.output(of1) - - val f2 = new Path(path2) - - val of2 = new TypeSerializerOutputFormat[B] - of2.setOutputFilePath(f2) - of2.setWriteMode(WriteMode.OVERWRITE) - - ds2.output(of2) - - val f3 = new Path(path3) - - val of3 = new TypeSerializerOutputFormat[C] - of3.setOutputFilePath(f3) - of3.setWriteMode(WriteMode.OVERWRITE) - - ds3.output(of3) - - env.execute("FlinkTools persist") - - val if1 = new TypeSerializerInputFormat[A](ds1.getType) - if1.setFilePath(f1) - - val if2 = new TypeSerializerInputFormat[B](ds2.getType) - if2.setFilePath(f2) - - val if3 = new TypeSerializerInputFormat[C](ds3.getType) - if3.setFilePath(f3) - - (env.createInput(if1), env.createInput(if2), env.createInput(if3)) - } - - /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for - * subsequent operations. - * - * @param ds1 First [[DataSet]] to write to disk - * @param ds2 Second [[DataSet]] to write to disk - * @param ds3 Third [[DataSet]] to write to disk - * @param ds4 Fourth [[DataSet]] to write to disk - * @param path1 Path for ds1 - * @param path2 Path for ds2 - * @param path3 Path for ds3 - * @param path4 Path for ds4 - * @tparam A Type of first [[DataSet]]'s elements - * @tparam B Type of second [[DataSet]]'s elements - * @tparam C Type of third [[DataSet]]'s elements - * @tparam D Type of fourth [[DataSet]]'s elements - * @return Tuple of [[DataSet]]s reading the just written files - */ - def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation, - C: ClassTag: TypeInformation, D: ClassTag: TypeInformation](ds1: DataSet[A], ds2: DataSet[B], - ds3: DataSet[C], ds4: DataSet[D], - path1: String, path2: String, path3: - String, path4: String): - (DataSet[A], DataSet[B], DataSet[C], DataSet[D]) = { - val env = ds1.getExecutionEnvironment - - val f1 = new Path(path1) - - val of1 = new TypeSerializerOutputFormat[A] - of1.setOutputFilePath(f1) - of1.setWriteMode(WriteMode.OVERWRITE) - - ds1.output(of1) - - val f2 = new Path(path2) - - val of2 = new TypeSerializerOutputFormat[B] - of2.setOutputFilePath(f2) - of2.setWriteMode(WriteMode.OVERWRITE) - - ds2.output(of2) - - val f3 = new Path(path3) - - val of3 = new TypeSerializerOutputFormat[C] - of3.setOutputFilePath(f3) - of3.setWriteMode(WriteMode.OVERWRITE) - - ds3.output(of3) - - val f4 = new Path(path4) - - val of4 = new TypeSerializerOutputFormat[D] - of4.setOutputFilePath(f4) - of4.setWriteMode(WriteMode.OVERWRITE) - - ds4.output(of4) - - env.execute("FlinkTools persist") - - val if1 = new TypeSerializerInputFormat[A](ds1.getType) - if1.setFilePath(f1) - - val if2 = new TypeSerializerInputFormat[B](ds2.getType) - 
if2.setFilePath(f2) - - val if3 = new TypeSerializerInputFormat[C](ds3.getType) - if3.setFilePath(f3) - - val if4 = new TypeSerializerInputFormat[D](ds4.getType) - if4.setFilePath(f4) - - (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4)) - } - - /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for - * subsequent operations. - * - * @param ds1 First [[DataSet]] to write to disk - * @param ds2 Second [[DataSet]] to write to disk - * @param ds3 Third [[DataSet]] to write to disk - * @param ds4 Fourth [[DataSet]] to write to disk - * @param ds5 Fifth [[DataSet]] to write to disk - * @param path1 Path for ds1 - * @param path2 Path for ds2 - * @param path3 Path for ds3 - * @param path4 Path for ds4 - * @param path5 Path for ds5 - * @tparam A Type of first [[DataSet]]'s elements - * @tparam B Type of second [[DataSet]]'s elements - * @tparam C Type of third [[DataSet]]'s elements - * @tparam D Type of fourth [[DataSet]]'s elements - * @tparam E Type of fifth [[DataSet]]'s elements - * @return Tuple of [[DataSet]]s reading the just written files - */ - def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation, - C: ClassTag: TypeInformation, D: ClassTag: TypeInformation, E: ClassTag: TypeInformation] - (ds1: DataSet[A], ds2: DataSet[B], ds3: DataSet[C], ds4: DataSet[D], ds5: DataSet[E], path1: - String, path2: String, path3: String, path4: String, path5: String): (DataSet[A], DataSet[B], - DataSet[C], DataSet[D], DataSet[E]) = { - val env = ds1.getExecutionEnvironment - - val f1 = new Path(path1) - - val of1 = new TypeSerializerOutputFormat[A] - of1.setOutputFilePath(f1) - of1.setWriteMode(WriteMode.OVERWRITE) - - ds1.output(of1) - - val f2 = new Path(path2) - - val of2 = new TypeSerializerOutputFormat[B] - of2.setOutputFilePath(f2) - of2.setOutputDirectoryMode(OutputDirectoryMode.ALWAYS) - of2.setWriteMode(WriteMode.OVERWRITE) - - ds2.output(of2) - - val f3 = new Path(path3) - - val of3 = new TypeSerializerOutputFormat[C] - of3.setOutputFilePath(f3) - of3.setWriteMode(WriteMode.OVERWRITE) - - ds3.output(of3) - - val f4 = new Path(path4) - - val of4 = new TypeSerializerOutputFormat[D] - of4.setOutputFilePath(f4) - of4.setWriteMode(WriteMode.OVERWRITE) - - ds4.output(of4) - - val f5 = new Path(path5) - - val of5 = new TypeSerializerOutputFormat[E] - of5.setOutputFilePath(f5) - of5.setWriteMode(WriteMode.OVERWRITE) - - ds5.output(of5) - - env.execute("FlinkTools persist") - - val if1 = new TypeSerializerInputFormat[A](ds1.getType) - if1.setFilePath(f1) - - val if2 = new TypeSerializerInputFormat[B](ds2.getType) - if2.setFilePath(f2) - - val if3 = new TypeSerializerInputFormat[C](ds3.getType) - if3.setFilePath(f3) - - val if4 = new TypeSerializerInputFormat[D](ds4.getType) - if4.setFilePath(f4) - - val if5 = new TypeSerializerInputFormat[E](ds5.getType) - if5.setFilePath(f5) - - (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env - .createInput(if5)) - } - - /** Groups the DataSet input into numBlocks blocks. 
-    *
-    * @param input [[DataSet]] to be split into blocks
-    * @param numBlocks Number of blocks
-    * @param partitionerOption Optional partitioner to control the partitioning
-    * @tparam T Type of the input elements
-    * @return [[DataSet]] of [[Block]]s, each containing the elements of one block
-    */
-  def block[T: TypeInformation: ClassTag](
-      input: DataSet[T],
-      numBlocks: Int,
-      partitionerOption: Option[Partitioner[Int]] = None)
-    : DataSet[Block[T]] = {
-    val blockIDInput = input map {
-      element =>
-        val blockID = element.hashCode() % numBlocks
-
-        // hashCode may be negative; shift negative remainders into [0, numBlocks)
-        val blockIDResult = if(blockID < 0){
-          blockID + numBlocks
-        } else {
-          blockID
-        }
-
-        (blockIDResult, element)
-    }
-
-    val preGroupBlockIDInput = partitionerOption match {
-      case Some(partitioner) =>
-        blockIDInput partitionCustom(partitioner, 0)
-
-      case None => blockIDInput
-    }
-
-    preGroupBlockIDInput.groupBy(0).reduceGroup {
-      iter => {
-        val array = iter.toVector
-
-        val blockID = array(0)._1
-        val elements = array.map(_._2)
-
-        Block[T](blockID, elements)
-      }
-    }.withForwardedFields("0 -> index")
-  }
-
-  /** Assigns each element to a partition by taking the modulo of its key with the number of
-    * partitions.
-    */
-  object ModuloKeyPartitioner extends Partitioner[Int] {
-    override def partition(key: Int, numPartitions: Int): Int = {
-      val result = key % numPartitions
-
-      if(result < 0) {
-        result + numPartitions
-      } else {
-        result
-      }
-    }
-  }
-}
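For reference, a minimal sketch of how the FlinkMLTools API above was typically used; the dataset, the path, and the block count are invented for illustration. persist runs an intermediate job that materializes the DataSet on disk and returns a source reading the written files, so later operations no longer re-execute the upstream graph.

import org.apache.flink.api.scala._
import org.apache.flink.ml.common.FlinkMLTools

object FlinkMLToolsExample {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    val data: DataSet[Double] = env.fromElements(1.0, 2.0, 3.0, 4.0)

    // Cut the execution graph here: `data` is written to disk by a first job and
    // the returned DataSet reads the written files back as a source.
    val persisted = FlinkMLTools.persist(data, "/tmp/flinkml/persisted-data")

    // Split the elements into two blocks, pinning the assignment of block IDs to
    // partitions with the modulo partitioner.
    val blocked = FlinkMLTools.block(persisted, 2, Some(FlinkMLTools.ModuloKeyPartitioner))

    blocked.collect().foreach(println)
  }
}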
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
deleted file mode 100644
index 3b948c0..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/LabeledVector.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common
-
-import org.apache.flink.ml.math.Vector
-
-/** This class represents a vector with an associated label, as required by many supervised
-  * learning tasks.
-  *
-  * @param label Label of the data point
-  * @param vector Data point
-  */
-case class LabeledVector(label: Double, vector: Vector) extends Serializable {
-
-  override def equals(obj: Any): Boolean = {
-    obj match {
-      case labeledVector: LabeledVector =>
-        vector.equals(labeledVector.vector) && label.equals(labeledVector.label)
-      case _ => false
-    }
-  }
-
-  override def toString: String = {
-    s"LabeledVector($label, $vector)"
-  }
-}
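A quick illustration of the data type, with invented values:

import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.DenseVector

object LabeledVectorExample extends App {
  // Two labeled data points with two features each.
  val positive = LabeledVector(1.0, DenseVector(0.2, 1.4))
  val negative = LabeledVector(0.0, DenseVector(0.8, 0.1))

  // Printed via the toString override above, e.g. LabeledVector(1.0, DenseVector(0.2, 1.4))
  println(positive)

  // equals compares the label and the vector contents, so this is false
  println(positive == negative)
}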
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
deleted file mode 100644
index 77d2d46..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/ParameterMap.scala
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common
-
-import scala.collection.mutable
-
-/**
- * Map used to store configuration parameters for algorithms. The parameter
- * values are stored in a [[Map]], keyed by [[Parameter]] objects. ParameterMaps can be
- * combined with ++. If both maps contain a value for the same key, the value of the
- * right-hand (later) ParameterMap wins.
- *
- * @param map Map containing parameter settings
- */
-class ParameterMap(val map: mutable.Map[Parameter[_], Any]) extends Serializable {
-
-  def this() = {
-    this(new mutable.HashMap[Parameter[_], Any]())
-  }
-
-  /**
-   * Adds a new parameter value to the ParameterMap.
-   *
-   * @param parameter Key
-   * @param value Value associated with the given key
-   * @tparam T Type of value
-   * @return This ParameterMap, to allow chaining of add calls
-   */
-  def add[T](parameter: Parameter[T], value: T): ParameterMap = {
-    map += (parameter -> value)
-    this
-  }
-
-  /**
-   * Retrieves a parameter value associated with a given key. The value is returned as an
-   * Option. If there is no value associated with the given key, then the default value of the
-   * [[Parameter]] is returned.
-   *
-   * @param parameter Key
-   * @tparam T Type of the value to retrieve
-   * @return Some(value) if a value is associated with the given key, otherwise the default
-   *         value defined by parameter
-   */
-  def get[T](parameter: Parameter[T]): Option[T] = {
-    if(map.isDefinedAt(parameter)) {
-      map.get(parameter).asInstanceOf[Option[T]]
-    } else {
-      parameter.defaultValue
-    }
-  }
-
-  /**
-   * Retrieves a parameter value associated with a given key. If there is no value contained in
-   * the map, then the default value of the [[Parameter]] is checked. If the default value is
-   * defined, then it is returned. If the default is undefined, then a
-   * [[NoSuchElementException]] is thrown.
-   *
-   * @param parameter Key
-   * @tparam T Type of value
-   * @return Value associated with the given key or its default value
-   */
-  def apply[T](parameter: Parameter[T]): T = {
-    if(map.isDefinedAt(parameter)) {
-      map(parameter).asInstanceOf[T]
-    } else {
-      parameter.defaultValue match {
-        case Some(value) => value
-        case None => throw new NoSuchElementException(s"Could not retrieve " +
-          s"parameter value $parameter.")
-      }
-    }
-  }
-
-  /**
-   * Creates a new ParameterMap containing the parameter values of this map and of parameters.
-   * Values from parameters override values defined for the same key in this map.
-   *
-   * @param parameters [[ParameterMap]] containing the parameter values to be added
-   * @return New ParameterMap with the merged parameter values
-   */
-  def ++(parameters: ParameterMap): ParameterMap = {
-    // Clone the backing map so that the result does not alias this ParameterMap's state
-    val result = new ParameterMap(map.clone())
-    result.map ++= parameters.map
-
-    result
-  }
-}
-
-object ParameterMap {
-  val Empty = new ParameterMap
-
-  def apply(): ParameterMap = {
-    new ParameterMap
-  }
-}
-
-/**
- * Base trait for parameter keys
- *
- * @tparam T Type of parameter value associated with this parameter key
- */
-trait Parameter[T] {
-
-  /**
-   * Default value of the parameter. [[None]] if the parameter has no default value.
-   */
-  val defaultValue: Option[T]
-}
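A compact sketch of how Parameter keys and ParameterMap interact; the Iterations key and its values are invented for illustration:

import org.apache.flink.ml.common.{Parameter, ParameterMap}

object ParameterMapExample extends App {
  // Hypothetical parameter key carrying a default value.
  case object Iterations extends Parameter[Int] {
    val defaultValue: Option[Int] = Some(10)
  }

  val defaults = ParameterMap().add(Iterations, 50)
  val overrides = ParameterMap().add(Iterations, 100)

  // For shared keys, the right-hand map wins.
  val merged = defaults ++ overrides

  println(merged(Iterations))             // 100
  println(ParameterMap().get(Iterations)) // Some(10), falling back to the default
}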
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
deleted file mode 100644
index 4628c71..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common
-
-import org.apache.flink.ml.math.Vector
-
-// TODO(tvas): This provides an abstraction for the weights, but at the same time it leads to
-// the creation of many objects, as we have to pack and unpack the weights and the intercept
-// often during SGD.
-
-/** This class represents a weight vector with an intercept, as required by many supervised
-  * learning tasks.
-  *
-  * @param weights The vector of weights
-  * @param intercept The intercept (bias) weight
-  */
-case class WeightVector(weights: Vector, intercept: Double) extends Serializable
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala
deleted file mode 100644
index 24ac9e3..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WithParameters.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common
-
-/**
- * Adds a [[ParameterMap]] which can be used to store configuration values
- */
-trait WithParameters {
-  val parameters = new ParameterMap
-}
http://git-wip-us.apache.org/repos/asf/flink/blob/e9bf13d8/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala
deleted file mode 100644
index 8ea3b65..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BLAS.scala
+++ /dev/null
@@ -1,291 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.math
-
-import com.github.fommil.netlib.{BLAS => NetlibBLAS, F2jBLAS}
-import com.github.fommil.netlib.BLAS.{getInstance => NativeBLAS}
-
-/**
- * BLAS routines for vectors and matrices.
- * - * Original code from the Apache Spark project: - * http://git.io/vfZUe - */ -object BLAS extends Serializable { - - @transient private var _f2jBLAS: NetlibBLAS = _ - @transient private var _nativeBLAS: NetlibBLAS = _ - - // For level-1 routines, we use Java implementation. - private def f2jBLAS: NetlibBLAS = { - if (_f2jBLAS == null) { - _f2jBLAS = new F2jBLAS - } - _f2jBLAS - } - - /** - * y += a * x - */ - def axpy(a: Double, x: Vector, y: Vector): Unit = { - require(x.size == y.size) - y match { - case dy: DenseVector => - x match { - case sx: SparseVector => - axpy(a, sx, dy) - case dx: DenseVector => - axpy(a, dx, dy) - case _ => - throw new UnsupportedOperationException( - s"axpy doesn't support x type ${x.getClass}.") - } - case _ => - throw new IllegalArgumentException( - s"axpy only supports adding to a dense vector but got type ${y.getClass}.") - } - } - - /** - * y += a * x - */ - private def axpy(a: Double, x: DenseVector, y: DenseVector): Unit = { - val n = x.size - f2jBLAS.daxpy(n, a, x.data, 1, y.data, 1) - } - - /** - * y += a * x - */ - private def axpy(a: Double, x: SparseVector, y: DenseVector): Unit = { - val xValues = x.data - val xIndices = x.indices - val yValues = y.data - val nnz = xIndices.size - - if (a == 1.0) { - var k = 0 - while (k < nnz) { - yValues(xIndices(k)) += xValues(k) - k += 1 - } - } else { - var k = 0 - while (k < nnz) { - yValues(xIndices(k)) += a * xValues(k) - k += 1 - } - } - } - - /** - * dot(x, y) - */ - def dot(x: Vector, y: Vector): Double = { - require(x.size == y.size, - "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" + - " x.size = " + x.size + ", y.size = " + y.size) - (x, y) match { - case (dx: DenseVector, dy: DenseVector) => - dot(dx, dy) - case (sx: SparseVector, dy: DenseVector) => - dot(sx, dy) - case (dx: DenseVector, sy: SparseVector) => - dot(sy, dx) - case (sx: SparseVector, sy: SparseVector) => - dot(sx, sy) - case _ => - throw new IllegalArgumentException(s"dot doesn't support (${x.getClass}, ${y.getClass}).") - } - } - - /** - * dot(x, y) - */ - private def dot(x: DenseVector, y: DenseVector): Double = { - val n = x.size - f2jBLAS.ddot(n, x.data, 1, y.data, 1) - } - - /** - * dot(x, y) - */ - private def dot(x: SparseVector, y: DenseVector): Double = { - val xValues = x.data - val xIndices = x.indices - val yValues = y.data - val nnz = xIndices.size - - var sum = 0.0 - var k = 0 - while (k < nnz) { - sum += xValues(k) * yValues(xIndices(k)) - k += 1 - } - sum - } - - /** - * dot(x, y) - */ - private def dot(x: SparseVector, y: SparseVector): Double = { - val xValues = x.data - val xIndices = x.indices - val yValues = y.data - val yIndices = y.indices - val nnzx = xIndices.size - val nnzy = yIndices.size - - var kx = 0 - var ky = 0 - var sum = 0.0 - // y catching x - while (kx < nnzx && ky < nnzy) { - val ix = xIndices(kx) - while (ky < nnzy && yIndices(ky) < ix) { - ky += 1 - } - if (ky < nnzy && yIndices(ky) == ix) { - sum += xValues(kx) * yValues(ky) - ky += 1 - } - kx += 1 - } - sum - } - - /** - * y = x - */ - def copy(x: Vector, y: Vector): Unit = { - val n = y.size - require(x.size == n) - y match { - case dy: DenseVector => - x match { - case sx: SparseVector => - val sxIndices = sx.indices - val sxValues = sx.data - val dyValues = dy.data - val nnz = sxIndices.size - - var i = 0 - var k = 0 - while (k < nnz) { - val j = sxIndices(k) - while (i < j) { - dyValues(i) = 0.0 - i += 1 - } - dyValues(i) = sxValues(k) - i += 1 - k += 1 - } - while (i < n) { - dyValues(i) = 0.0 - i += 1 - } - 
case dx: DenseVector => - Array.copy(dx.data, 0, dy.data, 0, n) - } - case _ => - throw new IllegalArgumentException(s"y must be dense in copy but got ${y.getClass}") - } - } - - /** - * x = a * x - */ - def scal(a: Double, x: Vector): Unit = { - x match { - case sx: SparseVector => - f2jBLAS.dscal(sx.data.size, a, sx.data, 1) - case dx: DenseVector => - f2jBLAS.dscal(dx.data.size, a, dx.data, 1) - case _ => - throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") - } - } - - // For level-3 routines, we use the native BLAS. - private def nativeBLAS: NetlibBLAS = { - if (_nativeBLAS == null) { - _nativeBLAS = NativeBLAS - } - _nativeBLAS - } - - /** - * A := alpha * x * x^T^ + A - * @param alpha a real scalar that will be multiplied to x * x^T^. - * @param x the vector x that contains the n elements. - * @param A the symmetric matrix A. Size of n x n. - */ - def syr(alpha: Double, x: Vector, A: DenseMatrix) { - val mA = A.numRows - val nA = A.numCols - require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA") - require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}") - - x match { - case dv: DenseVector => syr(alpha, dv, A) - case sv: SparseVector => syr(alpha, sv, A) - case _ => - throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.") - } - } - - private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) { - val nA = A.numRows - val mA = A.numCols - - nativeBLAS.dsyr("U", x.size, alpha, x.data, 1, A.data, nA) - - // Fill lower triangular part of A - var i = 0 - while (i < mA) { - var j = i + 1 - while (j < nA) { - A(j, i) = A(i, j) - j += 1 - } - i += 1 - } - } - - private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) { - val mA = A.numCols - val xIndices = x.indices - val xValues = x.data - val nnz = xValues.length - val Avalues = A.data - - var i = 0 - while (i < nnz) { - val multiplier = alpha * xValues(i) - val offset = xIndices(i) * mA - var j = 0 - while (j < nnz) { - Avalues(xIndices(j) + offset) += multiplier * xValues(j) - j += 1 - } - i += 1 - } - } -}
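To round off the removed math utilities, a small self-contained sketch of the BLAS helpers together with WeightVector; all values are invented. Note that axpy and scal mutate their dense arguments in place.

import org.apache.flink.ml.common.WeightVector
import org.apache.flink.ml.math.{BLAS, DenseVector}

object BlasExample extends App {
  val weights = DenseVector(0.5, -0.25, 1.0)
  val features = DenseVector(1.0, 2.0, 3.0)
  val model = WeightVector(weights, intercept = 0.1)

  // Linear model prediction: w . x + b
  val prediction = BLAS.dot(model.weights, features) + model.intercept
  println(prediction) // 3.1 for the values above

  // In-place update: weights += -0.01 * features
  // (e.g. a gradient step, with `features` standing in for the gradient)
  BLAS.axpy(-0.01, features, weights)

  // In-place scaling: weights = 0.5 * weights
  BLAS.scal(0.5, weights)

  println(weights)
}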