Repository: spark
Updated Branches:
  refs/heads/master 03c27435a -> 3b2b785ec
[SPARK-16675][SQL] Avoid per-record type dispatch in JDBC when writing

## What changes were proposed in this pull request?

Currently, `JdbcUtils.savePartition` performs type-based dispatch for every record in order to write each value. Instead, the appropriate setters for `PreparedStatement` can be created once, according to the schema, and then applied to each row. This approach is similar to the one in `CatalystWriteSupport`. This PR introduces such setters to avoid the per-record dispatch.

## How was this patch tested?

Existing tests should cover this.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #14323 from HyukjinKwon/SPARK-16675.
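The pattern is easy to see outside Spark's internals. Below is a minimal, hypothetical
sketch of the same idea against plain JDBC (the `ColType`, `ValueSetter`, and
`writePartition` names are illustrative, not Spark's): the type `match` runs once per
column while the setter array is built, so the per-row loop is reduced to an array
lookup and a call.

    import java.sql.{Connection, PreparedStatement}

    object PerColumnSetterSketch {
      // Stand-in for Spark's DataType, just enough to show the dispatch.
      sealed trait ColType
      case object IntCol extends ColType
      case object StringCol extends ColType

      // One function per column: (statement, row, 0-based column index) => Unit.
      type ValueSetter = (PreparedStatement, Seq[Any], Int) => Unit

      // Type dispatch happens here, once per column, not once per record.
      def makeSetter(t: ColType): ValueSetter = t match {
        case IntCol =>
          (stmt, row, pos) => stmt.setInt(pos + 1, row(pos).asInstanceOf[Int])
        case StringCol =>
          (stmt, row, pos) => stmt.setString(pos + 1, row(pos).asInstanceOf[String])
      }

      def writePartition(conn: Connection, schema: Seq[ColType], rows: Iterator[Seq[Any]]): Unit = {
        val placeholders = Seq.fill(schema.length)("?").mkString(", ")
        val stmt = conn.prepareStatement(s"INSERT INTO t VALUES ($placeholders)")
        try {
          val setters: Array[ValueSetter] = schema.map(makeSetter).toArray
          rows.foreach { row =>
            var i = 0
            while (i < setters.length) {
              setters(i)(stmt, row, i) // hot loop: no type match per value
              i += 1
            }
            stmt.addBatch()
          }
          stmt.executeBatch()
        } finally {
          stmt.close()
        }
      }
    }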
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3b2b785e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3b2b785e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3b2b785e

Branch: refs/heads/master
Commit: 3b2b785ece4394ca332377647a6305ea493f411b
Parents: 03c2743
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Tue Jul 26 17:14:58 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Jul 26 17:14:58 2016 +0800

----------------------------------------------------------------------
 .../execution/datasources/jdbc/JDBCRDD.scala   |  22 ++--
 .../execution/datasources/jdbc/JdbcUtils.scala | 102 ++++++++++++++-----
 2 files changed, 88 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3b2b785e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 4c98430..e267e77 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -322,19 +322,19 @@ private[sql] class JDBCRDD(
     }
   }
 
-  // A `JDBCValueSetter` is responsible for converting and setting a value from `ResultSet`
-  // into a field for `MutableRow`. The last argument `Int` means the index for the
-  // value to be set in the row and also used for the value to retrieve from `ResultSet`.
-  private type JDBCValueSetter = (ResultSet, MutableRow, Int) => Unit
+  // A `JDBCValueGetter` is responsible for getting a value from `ResultSet` into a field
+  // for `MutableRow`. The last argument `Int` means the index for the value to be set in
+  // the row and also used for the value in `ResultSet`.
+  private type JDBCValueGetter = (ResultSet, MutableRow, Int) => Unit
 
   /**
-   * Creates `JDBCValueSetter`s according to [[StructType]], which can set
+   * Creates `JDBCValueGetter`s according to [[StructType]], which can set
    * each value from `ResultSet` to each field of [[MutableRow]] correctly.
    */
-  def makeSetters(schema: StructType): Array[JDBCValueSetter] =
-    schema.fields.map(sf => makeSetter(sf.dataType, sf.metadata))
+  def makeGetters(schema: StructType): Array[JDBCValueGetter] =
+    schema.fields.map(sf => makeGetter(sf.dataType, sf.metadata))
 
-  private def makeSetter(dt: DataType, metadata: Metadata): JDBCValueSetter = dt match {
+  private def makeGetter(dt: DataType, metadata: Metadata): JDBCValueGetter = dt match {
     case BooleanType =>
       (rs: ResultSet, row: MutableRow, pos: Int) =>
         row.setBoolean(pos, rs.getBoolean(pos + 1))
@@ -489,15 +489,15 @@ private[sql] class JDBCRDD(
     stmt.setFetchSize(fetchSize)
     val rs = stmt.executeQuery()
 
-    val setters: Array[JDBCValueSetter] = makeSetters(schema)
+    val getters: Array[JDBCValueGetter] = makeGetters(schema)
     val mutableRow = new SpecificMutableRow(schema.fields.map(x => x.dataType))
 
     def getNext(): InternalRow = {
       if (rs.next()) {
         inputMetrics.incRecordsRead(1)
         var i = 0
-        while (i < setters.length) {
-          setters(i).apply(rs, mutableRow, i)
+        while (i < getters.length) {
+          getters(i).apply(rs, mutableRow, i)
           if (rs.wasNull) mutableRow.setNullAt(i)
           i = i + 1
         }
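As the hunks above show, the read path had already been refactored in this style; the
diff only renames `JDBCValueSetter`/`makeSetters` to `JDBCValueGetter`/`makeGetters`,
which better describes pulling values out of a `ResultSet`. Two details in the loop are
worth calling out: `ResultSet` columns are 1-based while row fields are 0-based (hence
`pos + 1`), and primitive accessors such as `getInt` return 0 for SQL NULL, so
`rs.wasNull` has to be checked after each getter runs. A standalone, hypothetical
illustration in the same spirit, with `Array[Any]` standing in for `MutableRow`:

    import java.sql.ResultSet
    import scala.collection.mutable.ArrayBuffer

    object PerColumnGetterSketch {
      // One function per column: (result set, row buffer, 0-based column index) => Unit.
      type ValueGetter = (ResultSet, Array[Any], Int) => Unit

      val intGetter: ValueGetter =
        (rs, row, pos) => { row(pos) = rs.getInt(pos + 1) } // ResultSet is 1-based

      val stringGetter: ValueGetter =
        (rs, row, pos) => { row(pos) = rs.getString(pos + 1) }

      def readAll(rs: ResultSet, getters: Array[ValueGetter]): Seq[Array[Any]] = {
        val rows = ArrayBuffer.empty[Array[Any]]
        while (rs.next()) {
          val row = new Array[Any](getters.length)
          var i = 0
          while (i < getters.length) {
            getters(i)(rs, row, i)
            // getInt returns 0 for SQL NULL, so null-ness is checked afterwards.
            if (rs.wasNull) row(i) = null
            i += 1
          }
          rows += row
        }
        rows.toSeq
      }
    }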

http://git-wip-us.apache.org/repos/asf/spark/blob/3b2b785e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index cb474cb..81d38e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -154,6 +154,79 @@ object JdbcUtils extends Logging {
       throw new IllegalArgumentException(s"Can't get JDBC type for ${dt.simpleString}"))
   }
 
+  // A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
+  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
+  // in the SQL statement and also used for the value in `Row`.
+  private type JDBCValueSetter = (PreparedStatement, Row, Int) => Unit
+
+  private def makeSetter(
+      conn: Connection,
+      dialect: JdbcDialect,
+      dataType: DataType): JDBCValueSetter = dataType match {
+    case IntegerType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getInt(pos))
+
+    case LongType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setLong(pos + 1, row.getLong(pos))
+
+    case DoubleType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDouble(pos + 1, row.getDouble(pos))
+
+    case FloatType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setFloat(pos + 1, row.getFloat(pos))
+
+    case ShortType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getShort(pos))
+
+    case ByteType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setInt(pos + 1, row.getByte(pos))
+
+    case BooleanType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBoolean(pos + 1, row.getBoolean(pos))
+
+    case StringType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setString(pos + 1, row.getString(pos))
+
+    case BinaryType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos))
+
+    case TimestampType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos))
+
+    case DateType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos))
+
+    case t: DecimalType =>
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        stmt.setBigDecimal(pos + 1, row.getDecimal(pos))
+
+    case ArrayType(et, _) =>
+      // remove type length parameters from end of type name
+      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
+        .toLowerCase.split("\\(")(0)
+      (stmt: PreparedStatement, row: Row, pos: Int) =>
+        val array = conn.createArrayOf(
+          typeName,
+          row.getSeq[AnyRef](pos).toArray)
+        stmt.setArray(pos + 1, array)
+
+    case _ =>
+      (_: PreparedStatement, _: Row, pos: Int) =>
+        throw new IllegalArgumentException(
+          s"Can't translate non-null value for field $pos")
+  }
+
   /**
    * Saves a partition of a DataFrame to the JDBC database. This is done in
    * a single database transaction (unless isolation level is "NONE")
@@ -215,6 +288,9 @@ object JdbcUtils extends Logging {
       conn.setTransactionIsolation(finalIsolationLevel)
     }
     val stmt = insertStatement(conn, table, rddSchema, dialect)
+    val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
+      .map(makeSetter(conn, dialect, _)).toArray
+
     try {
       var rowCount = 0
       while (iterator.hasNext) {
@@ -225,30 +301,7 @@ object JdbcUtils extends Logging {
           if (row.isNullAt(i)) {
             stmt.setNull(i + 1, nullTypes(i))
           } else {
-            rddSchema.fields(i).dataType match {
-              case IntegerType => stmt.setInt(i + 1, row.getInt(i))
-              case LongType => stmt.setLong(i + 1, row.getLong(i))
-              case DoubleType => stmt.setDouble(i + 1, row.getDouble(i))
-              case FloatType => stmt.setFloat(i + 1, row.getFloat(i))
-              case ShortType => stmt.setInt(i + 1, row.getShort(i))
-              case ByteType => stmt.setInt(i + 1, row.getByte(i))
-              case BooleanType => stmt.setBoolean(i + 1, row.getBoolean(i))
-              case StringType => stmt.setString(i + 1, row.getString(i))
-              case BinaryType => stmt.setBytes(i + 1, row.getAs[Array[Byte]](i))
-              case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i))
-              case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i))
-              case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i))
-              case ArrayType(et, _) =>
-                // remove type length parameters from end of type name
-                val typeName = getJdbcType(et, dialect).databaseTypeDefinition
-                  .toLowerCase.split("\\(")(0)
-                val array = conn.createArrayOf(
-                  typeName,
-                  row.getSeq[AnyRef](i).toArray)
-                stmt.setArray(i + 1, array)
-              case _ => throw new IllegalArgumentException(
-                s"Can't translate non-null value for field $i")
-            }
+            setters(i).apply(stmt, row, i)
          }
           i = i + 1
         }
@@ -333,5 +386,4 @@ object JdbcUtils extends Logging {
       getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect, isolationLevel)
     )
   }
-
 }
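Two behavioral points on the write side are easy to miss in the hunks above. First,
NULLs never reach the setters: `savePartition` still calls `stmt.setNull(i + 1,
nullTypes(i))` itself, so each setter only has to handle non-null values, and the
catch-all setter can throw unconditionally without changing the old error behavior for
unsupported types. Second, in the `ArrayType` branch the dialect's type definition may
carry a length suffix (for example `VARCHAR(100)`) that `Connection.createArrayOf` does
not take, which is what the `split("\\(")(0)` strips; the new code also computes this
once per column rather than once per value. A hedged sketch of both points, reusing the
illustrative `ValueSetter` shape from earlier:

    import java.sql.PreparedStatement

    object WriteLoopSketch {
      type ValueSetter = (PreparedStatement, Seq[Any], Int) => Unit

      // Per-row body of the write loop after the refactor (row type simplified).
      // `nullTypes` mirrors the real code's array of java.sql.Types constants.
      def setRow(
          stmt: PreparedStatement,
          row: Seq[Any],
          nullTypes: Array[Int],
          setters: Array[ValueSetter]): Unit = {
        var i = 0
        while (i < setters.length) {
          if (row(i) == null) {
            stmt.setNull(i + 1, nullTypes(i)) // NULLs bypass the setters entirely
          } else {
            setters(i)(stmt, row, i) // non-null values take the precomputed path
          }
          i += 1
        }
      }

      // How the ArrayType branch derives the element type name: a dialect may report
      // e.g. "VARCHAR(100)", but createArrayOf expects the bare type name.
      def baseTypeName(databaseTypeDefinition: String): String =
        databaseTypeDefinition.toLowerCase.split("\\(")(0)
      // baseTypeName("VARCHAR(100)") == "varchar"
    }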