This is an automated email from the ASF dual-hosted git repository.
diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push:
new 8781834 [feature] read doris via arrow flight sql (#227)
8781834 is described below
commit 8781834868291c57d4d00eafc3a0fcda2b6914e2
Author: gnehil <[email protected]>
AuthorDate: Tue Aug 27 16:24:53 2024 +0800
[feature] read doris via arrow flight sql (#227)
---
spark-doris-connector/pom.xml | 130 +++++++++++++++++++--
.../doris/spark/cfg/ConfigurationOptions.java | 5 +
.../org/apache/doris/spark/cfg/SparkSettings.java | 19 ++-
.../doris/spark/rest/PartitionDefinition.java | 34 +++---
.../org/apache/doris/spark/rest/RestService.java | 25 ++--
.../apache/doris/spark/serialization/RowBatch.java | 81 ++++++++-----
.../doris/spark/rdd/AbstractDorisRDDIterator.scala | 6 +-
.../doris/spark/rdd/AbstractValueReader.scala | 36 ++++++
.../doris/spark/rdd/ScalaADBCValueReader.scala | 128 ++++++++++++++++++++
.../org/apache/doris/spark/rdd/ScalaDorisRDD.scala | 23 ++--
.../apache/doris/spark/rdd/ScalaValueReader.scala | 10 +-
.../spark/sql/ScalaDorisRowADBCValueReader.scala | 50 ++++++++
.../apache/doris/spark/sql/ScalaDorisRowRDD.scala | 12 +-
.../scala/org/apache/doris/spark/sql/Utils.scala | 30 +++++
.../org/apache/doris/spark/sql/TestUtils.scala | 27 ++++-
15 files changed, 510 insertions(+), 106 deletions(-)
diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml
index a9fd180..bb82e6b 100644
--- a/spark-doris-connector/pom.xml
+++ b/spark-doris-connector/pom.xml
@@ -73,10 +73,10 @@
<scala.version>2.12.10</scala.version>
<scala.major.version>2.12</scala.major.version>
<libthrift.version>0.16.0</libthrift.version>
- <arrow.version>13.0.0</arrow.version>
+ <arrow.version>15.0.2</arrow.version>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.scm.id>github</project.scm.id>
- <netty.version>4.1.77.Final</netty.version>
+ <netty.version>4.1.104.Final</netty.version>
<fasterxml.jackson.version>2.13.5</fasterxml.jackson.version>
<thrift-service.version>1.0.1</thrift-service.version>
<testcontainers.version>1.17.6</testcontainers.version>
@@ -94,12 +94,6 @@
</exclusion>
</exclusions>
</dependency>
- <dependency>
- <groupId>io.netty</groupId>
- <artifactId>netty-all</artifactId>
- <version>${netty.version}</version>
- <scope>provided</scope>
- </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
@@ -248,6 +242,93 @@
<version>4.5.13</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.major.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.arrow.adbc</groupId>
+ <artifactId>adbc-driver-flight-sql</artifactId>
+ <version>0.13.0</version>
+ <exclusions>
+ <exclusion>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>flight-sql-jdbc-core</artifactId>
+ </exclusion>
+ <exclusion>
+ <artifactId>arrow-memory-netty</artifactId>
+ <groupId>org.apache.arrow</groupId>
+ </exclusion>
+ <exclusion>
+ <artifactId>arrow-memory-core</artifactId>
+ <groupId>org.apache.arrow</groupId>
+ </exclusion>
+ <exclusion>
+ <artifactId>arrow-format</artifactId>
+ <groupId>org.apache.arrow</groupId>
+ </exclusion>
+ <exclusion>
+ <artifactId>arrow-vector</artifactId>
+ <groupId>org.apache.arrow</groupId>
+ </exclusion>
+ <exclusion>
+ <artifactId>grpc-netty</artifactId>
+ <groupId>io.grpc</groupId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-netty</artifactId>
+ <version>1.60.0</version>
+ <exclusions>
+ <exclusion>
+ <artifactId>netty-codec-http2</artifactId>
+ <groupId>io.netty</groupId>
+ </exclusion>
+ <exclusion>
+ <artifactId>netty-handler-proxy</artifactId>
+ <groupId>io.netty</groupId>
+ </exclusion>
+ <exclusion>
+ <artifactId>netty-transport-native-unix-common</artifactId>
+ <groupId>io.netty</groupId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-codec-http2</artifactId>
+ <version>${netty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-handler-proxy</artifactId>
+ <version>${netty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-unix-common</artifactId>
+ <version>${netty.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>flight-sql-jdbc-core</artifactId>
+ <version>${arrow.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>3.22.3</version>
+ </dependency>
+
</dependencies>
<build>
@@ -291,7 +372,7 @@
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
- <version>3.2.1</version>
+ <version>3.4.1</version>
<executions>
<execution>
<id>scala-compile-first</id>
@@ -317,8 +398,20 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
- <version>3.2.1</version>
+ <version>3.4.1</version>
<configuration>
+ <filters>
+ <filter>
+                            <!-- Do not copy the signatures in the META-INF folder.
+                                Otherwise, this might cause SecurityExceptions when using the JAR. -->
+ <artifact>*:*</artifact>
+ <excludes>
+ <exclude>META-INF/*.SF</exclude>
+ <exclude>META-INF/*.DSA</exclude>
+ <exclude>META-INF/*.RSA</exclude>
+ </excludes>
+ </filter>
+ </filters>
<artifactSet>
<excludes>
<exclude>com.google.code.findbugs:*</exclude>
@@ -355,7 +448,20 @@
<pattern>org.apache.http</pattern>
<shadedPattern>org.apache.doris.shaded.org.apache.http</shadedPattern>
</relocation>
+ <relocation>
+ <pattern>io.grpc</pattern>
+                                <shadedPattern>org.apache.doris.shaded.io.grpc</shadedPattern>
+ </relocation>
+ <relocation>
+ <pattern>com.google</pattern>
+                                <shadedPattern>org.apache.doris.shaded.com.google</shadedPattern>
+ </relocation>
</relocations>
+ <transformers>
+ <transformer
+                                implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+ </transformers>
+ <!-- <minimizeJar>true</minimizeJar> -->
</configuration>
<executions>
<execution>
@@ -370,8 +476,8 @@
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
- <source>8</source>
- <target>8</target>
+ <source>1.8</source>
+ <target>1.8</target>
</configuration>
</plugin>
<plugin>
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 8f64c74..68f4ba8 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -161,4 +161,9 @@ public interface ConfigurationOptions {
"off_mode"
)));
+ String DORIS_READ_MODE = "doris.read.mode";
+ String DORIS_READ_MODE_DEFAULT = "thrift";
+
+ String DORIS_ARROW_FLIGHT_SQL_PORT = "doris.arrow-flight-sql.port";
+
}
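For reference, a minimal sketch of how these new options might be set when reading through the connector (assuming the connector's usual `doris` data source name; the FE address, Arrow Flight SQL port, table identifier and credentials below are placeholders):

    // Sketch only: switches the connector from the default "thrift" read mode
    // to the Arrow Flight SQL path added by this commit.
    val df = spark.read
      .format("doris")
      .option("doris.table.identifier", "example_db.example_tbl")   // placeholder table
      .option("doris.fenodes", "127.0.0.1:8030")                    // placeholder FE host:http_port
      .option("doris.request.auth.user", "root")                    // placeholder credentials
      .option("doris.request.auth.password", "")
      .option("doris.read.mode", "arrow")                           // DORIS_READ_MODE; default is "thrift"
      .option("doris.arrow-flight-sql.port", "9090")                // DORIS_ARROW_FLIGHT_SQL_PORT; placeholder port
      .load()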
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java
index 39fcd75..1448d2f 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java
@@ -17,16 +17,14 @@
package org.apache.doris.spark.cfg;
-import java.util.Properties;
-
-import org.apache.spark.SparkConf;
-
import com.google.common.base.Preconditions;
-
+import org.apache.spark.SparkConf;
import scala.Option;
import scala.Serializable;
import scala.Tuple2;
+import java.util.Properties;
+
public class SparkSettings extends Settings implements Serializable {
private final SparkConf cfg;
@@ -36,6 +34,16 @@ public class SparkSettings extends Settings implements Serializable {
this.cfg = cfg;
}
+ public static SparkSettings fromProperties(Properties props) {
+ SparkConf sparkConf = new SparkConf();
+ props.forEach((k, v) -> {
+ if (k instanceof String) {
+ sparkConf.set((String) k, v.toString());
+ }
+ });
+ return new SparkSettings(sparkConf);
+ }
+
public SparkSettings copy() {
return new SparkSettings(cfg.clone());
}
@@ -74,4 +82,5 @@ public class SparkSettings extends Settings implements Serializable {
return props;
}
+
}
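A small sketch of the new fromProperties helper, which ScalaADBCValueReader uses later via settings.asProperties() (the property values here are placeholders):

    // Sketch only: build SparkSettings from a plain java.util.Properties object.
    import java.util.Properties
    import org.apache.doris.spark.cfg.SparkSettings

    val props = new Properties()
    props.setProperty("doris.fenodes", "127.0.0.1:8030")        // placeholder FE address
    props.setProperty("doris.arrow-flight-sql.port", "9090")    // placeholder Arrow Flight SQL port
    val settings: SparkSettings = SparkSettings.fromProperties(props)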
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java
index 0c2aae3..baa517a 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java
@@ -17,16 +17,16 @@
package org.apache.doris.spark.rest;
+import org.apache.doris.spark.cfg.PropertiesSettings;
+import org.apache.doris.spark.cfg.Settings;
+import org.apache.doris.spark.exception.IllegalArgumentException;
+
import java.io.Serializable;
import java.util.Collections;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;
-import org.apache.doris.spark.cfg.PropertiesSettings;
-import org.apache.doris.spark.cfg.Settings;
-import org.apache.doris.spark.exception.IllegalArgumentException;
-
/**
* Doris RDD partition info.
*/
@@ -124,12 +124,12 @@ public class PartitionDefinition implements Serializable, Comparable<PartitionDe
return false;
}
PartitionDefinition that = (PartitionDefinition) o;
- return Objects.equals(database, that.database) &&
- Objects.equals(table, that.table) &&
- Objects.equals(beAddress, that.beAddress) &&
- Objects.equals(tabletIds, that.tabletIds) &&
- Objects.equals(queryPlan, that.queryPlan) &&
- Objects.equals(serializedSettings, that.serializedSettings);
+ return Objects.equals(database, that.database)
+ && Objects.equals(table, that.table)
+ && Objects.equals(beAddress, that.beAddress)
+ && Objects.equals(tabletIds, that.tabletIds)
+ && Objects.equals(queryPlan, that.queryPlan)
+ && Objects.equals(serializedSettings, that.serializedSettings);
}
@Override
@@ -144,12 +144,12 @@ public class PartitionDefinition implements Serializable, Comparable<PartitionDe
@Override
public String toString() {
- return "PartitionDefinition{" +
- ", database='" + database + '\'' +
- ", table='" + table + '\'' +
- ", beAddress='" + beAddress + '\'' +
- ", tabletIds=" + tabletIds +
- ", queryPlan='" + queryPlan + '\'' +
- '}';
+ return "PartitionDefinition{"
+ + ", database='" + database + '\''
+ + ", table='" + table + '\''
+ + ", beAddress='" + beAddress + '\''
+ + ", tabletIds=" + tabletIds
+ + ", queryPlan='" + queryPlan + '\''
+ + '}';
}
}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java
index 3f3516f..50432a6 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java
@@ -42,6 +42,7 @@ import org.apache.doris.spark.rest.models.QueryPlan;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.rest.models.Tablet;
import org.apache.doris.spark.sql.SchemaUtils;
+import org.apache.doris.spark.sql.Utils;
import org.apache.doris.spark.util.HttpUtil;
import org.apache.doris.spark.util.URLs;
@@ -51,7 +52,6 @@ import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.json.JsonMapper;
import com.google.common.annotations.VisibleForTesting;
-import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpStatus;
@@ -64,6 +64,7 @@ import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
+import scala.Option;
import java.io.IOException;
import java.io.Serializable;
@@ -227,23 +228,11 @@ public class RestService implements Serializable {
        String[] tableIdentifiers = parseIdentifier(cfg.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER), logger);
        String readFields = cfg.getProperty(ConfigurationOptions.DORIS_READ_FIELD, "*");
-        if (!"*".equals(readFields)) {
-            String[] readFieldArr = readFields.split(",");
-            String[] bitmapColumns = cfg.getProperty(SchemaUtils.DORIS_BITMAP_COLUMNS(), "").split(",");
-            String[] hllColumns = cfg.getProperty(SchemaUtils.DORIS_HLL_COLUMNS(), "").split(",");
-            for (int i = 0; i < readFieldArr.length; i++) {
-                String readFieldName = readFieldArr[i].replaceAll("`", "");
-                if (ArrayUtils.contains(bitmapColumns, readFieldName)
-                        || ArrayUtils.contains(hllColumns, readFieldName)) {
-                    readFieldArr[i] = "'READ UNSUPPORTED' AS " + readFieldArr[i];
-                }
-            }
-            readFields = StringUtils.join(readFieldArr, ",");
-        }
-        String sql = "select " + readFields + " from `" + tableIdentifiers[0] + "`.`" + tableIdentifiers[1] + "`";
-        if (!StringUtils.isEmpty(cfg.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY))) {
-            sql += " where " + cfg.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY);
-        }
+        String[] bitmapColumns = cfg.getProperty(SchemaUtils.DORIS_BITMAP_COLUMNS(), "").split(",");
+        String[] hllColumns = cfg.getProperty(SchemaUtils.DORIS_HLL_COLUMNS(), "").split(",");
+        String sql = Utils.generateQueryStatement(readFields.split(","), bitmapColumns, hllColumns,
+                "`" + tableIdentifiers[0] + "`.`" + tableIdentifiers[1] + "`",
+                cfg.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY, ""), Option.empty());
logger.debug("Query SQL Sending to Doris FE is: '{}'.", sql);
String finalSql = sql;
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
index 319bd3c..c3e70e9 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java
@@ -45,6 +45,7 @@ import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.impl.UnionMapReader;
+import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.Types.MinorType;
@@ -79,6 +80,10 @@ import java.util.Objects;
*/
public class RowBatch {
    private static final Logger logger = LoggerFactory.getLogger(RowBatch.class);
+
+ private final List<Row> rowBatch = new ArrayList<>();
+ private final ArrowReader arrowReader;
+ private final Schema schema;
private static final ZoneId DEFAULT_ZONE_ID = ZoneId.systemDefault();
    private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder()
@@ -93,10 +98,7 @@ public class RowBatch {
    private final DateTimeFormatter dateTimeV2Formatter = DateTimeFormatter.ofPattern(DATETIMEV2_PATTERN);
    private final DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd");
- private final List<Row> rowBatch = new ArrayList<>();
- private final ArrowStreamReader arrowStreamReader;
- private final RootAllocator rootAllocator;
- private final Schema schema;
+ private RootAllocator rootAllocator = null;
// offset for iterate the rowBatch
private int offsetInRowBatch = 0;
private int rowCountInOneBatch = 0;
@@ -104,32 +106,15 @@ public class RowBatch {
private List<FieldVector> fieldVectors;
    public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisException {
- this.schema = schema;
+
this.rootAllocator = new RootAllocator(Integer.MAX_VALUE);
- this.arrowStreamReader = new ArrowStreamReader(
- new ByteArrayInputStream(nextResult.getRows()),
- rootAllocator
- );
+        this.arrowReader = new ArrowStreamReader(new ByteArrayInputStream(nextResult.getRows()), rootAllocator);
+ this.schema = schema;
+
try {
- VectorSchemaRoot root = arrowStreamReader.getVectorSchemaRoot();
- while (arrowStreamReader.loadNextBatch()) {
- fieldVectors = root.getFieldVectors();
- if (fieldVectors.size() > schema.size()) {
-                    logger.error("Data schema size '{}' should not be bigger than arrow field size '{}'.",
-                            schema.size(), fieldVectors.size());
-                    throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
- }
- if (fieldVectors.isEmpty() || root.getRowCount() == 0) {
- logger.debug("One batch in arrow has no data.");
- continue;
- }
- rowCountInOneBatch = root.getRowCount();
- // init the rowBatch
- for (int i = 0; i < rowCountInOneBatch; ++i) {
- rowBatch.add(new Row(fieldVectors.size()));
- }
- convertArrowToRowBatch();
- readRowCount += root.getRowCount();
+ VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
+ while (arrowReader.loadNextBatch()) {
+ readBatch(root);
}
} catch (Exception e) {
logger.error("Read Doris Data failed because: ", e);
@@ -137,6 +122,42 @@ public class RowBatch {
} finally {
close();
}
+
+ }
+
+ public RowBatch(ArrowReader reader, Schema schema) throws DorisException {
+
+ this.arrowReader = reader;
+ this.schema = schema;
+
+ try {
+ VectorSchemaRoot root = arrowReader.getVectorSchemaRoot();
+ readBatch(root);
+ } catch (Exception e) {
+ logger.error("Read Doris Data failed because: ", e);
+ throw new DorisException(e.getMessage());
+ }
+
+ }
+
+ private void readBatch(VectorSchemaRoot root) throws DorisException {
+ fieldVectors = root.getFieldVectors();
+ if (fieldVectors.size() > schema.size()) {
+            logger.error("Data schema size '{}' should not be bigger than arrow field size '{}'.",
+                    schema.size(), fieldVectors.size());
+            throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
+ }
+ if (fieldVectors.isEmpty() || root.getRowCount() == 0) {
+ logger.debug("One batch in arrow has no data.");
+ return;
+ }
+ rowCountInOneBatch = root.getRowCount();
+ // init the rowBatch
+ for (int i = 0; i < rowCountInOneBatch; ++i) {
+ rowBatch.add(new Row(fieldVectors.size()));
+ }
+ convertArrowToRowBatch();
+ readRowCount += root.getRowCount();
}
public static LocalDateTime longToLocalDateTime(long time) {
@@ -505,8 +526,8 @@ public class RowBatch {
public void close() {
try {
- if (arrowStreamReader != null) {
- arrowStreamReader.close();
+ if (arrowReader != null) {
+ arrowReader.close();
}
if (rootAllocator != null) {
rootAllocator.close();
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala
index 902c634..8e5f661 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala
@@ -33,14 +33,14 @@ private[spark] abstract class AbstractDorisRDDIterator[T](
private var closed = false
// the reader obtain data from Doris BE
- private lazy val reader = {
+ private lazy val reader: AbstractValueReader = {
initialized = true
val settings = partition.settings()
initReader(settings)
val valueReaderName = settings.getProperty(DORIS_VALUE_READER_CLASS)
- logger.debug(s"Use value reader '$valueReaderName'.")
+ logger.info(s"Use value reader '$valueReaderName'.")
    val cons = Class.forName(valueReaderName).getDeclaredConstructor(classOf[PartitionDefinition], classOf[Settings])
- cons.newInstance(partition, settings).asInstanceOf[ScalaValueReader]
+ cons.newInstance(partition, settings).asInstanceOf[AbstractValueReader]
}
context.addTaskCompletionListener(new TaskCompletionListener() {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractValueReader.scala
new file mode 100644
index 0000000..3c8acf6
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractValueReader.scala
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.rdd
+
+import org.apache.doris.spark.serialization.RowBatch
+
+trait AbstractValueReader {
+
+ protected var rowBatch: RowBatch = _
+
+ def hasNext: Boolean
+
+ /**
+ * get next value.
+ * @return next value
+ */
+ def next: AnyRef
+
+ def close(): Unit
+
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaADBCValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaADBCValueReader.scala
new file mode 100644
index 0000000..3cd3da0
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaADBCValueReader.scala
@@ -0,0 +1,128 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.rdd
+
+import org.apache.arrow.adbc.core.{AdbcConnection, AdbcDriver, AdbcStatement}
+import org.apache.arrow.adbc.driver.flightsql.FlightSqlDriver
+import org.apache.arrow.flight.Location
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector.ipc.ArrowReader
+import org.apache.commons.lang3.exception.ExceptionUtils
+import org.apache.doris.spark.cfg.{ConfigurationOptions, Settings, SparkSettings}
+import org.apache.doris.spark.exception.ShouldNeverHappenException
+import org.apache.doris.spark.rest.{PartitionDefinition, RestService}
+import org.apache.doris.spark.serialization.RowBatch
+import org.apache.doris.spark.sql.{SchemaUtils, Utils}
+import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE
+import org.apache.spark.internal.Logging
+
+import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+class ScalaADBCValueReader(partition: PartitionDefinition, settings: Settings) extends AbstractValueReader with Logging {
+
+ private[this] val eos: AtomicBoolean = new AtomicBoolean(false)
+
+  private lazy val schema = RestService.getSchema(SparkSettings.fromProperties(settings.asProperties()), log)
+
+ private lazy val conn: AdbcConnection = {
+ // val loader = ClassLoader.getSystemClassLoader
+ // val classesField = classOf[ClassLoader].getDeclaredField("classes")
+ // classesField.setAccessible(true)
+  //    val classes = classesField.get(loader).asInstanceOf[java.util.Vector[Any]]
+  //    classes.forEach(clazz => println(clazz.asInstanceOf[Class[_]].getName))
+  //    Class.forName("org.apache.doris.shaded.org.apache.arrow.memory.RootAllocator")
+ var allocator: BufferAllocator = null
+ try {
+ allocator = new RootAllocator()
+ } catch {
+ case e: Throwable => println(ExceptionUtils.getStackTrace(e))
+ throw e;
+ }
+ val driver = new FlightSqlDriver(allocator)
+ val params = mutable.HashMap[String, AnyRef]().asJava
+ AdbcDriver.PARAM_URI.set(params, Location.forGrpcInsecure(
+ settings.getProperty(ConfigurationOptions.DORIS_FENODES).split(":")(0),
+      settings.getIntegerProperty(ConfigurationOptions.DORIS_ARROW_FLIGHT_SQL_PORT)
+ ).getUri.toString)
+    AdbcDriver.PARAM_USERNAME.set(params, settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER))
+    AdbcDriver.PARAM_PASSWORD.set(params, settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD))
+ val database = driver.open(params)
+ database.connect()
+ }
+
+ private lazy val stmt: AdbcStatement = conn.createStatement()
+
+ private lazy val queryResult: AdbcStatement.QueryResult = {
+    val flightSql = Utils.generateQueryStatement(settings.getProperty(ConfigurationOptions.DORIS_READ_FIELD, "*").split(","),
+ settings.getProperty(SchemaUtils.DORIS_BITMAP_COLUMNS, "").split(","),
+ settings.getProperty(SchemaUtils.DORIS_HLL_COLUMNS, "").split(","),
+ s"`${partition.getDatabase}`.`${partition.getTable}`",
+ settings.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY, ""),
+ Some(partition)
+ )
+ log.info(s"flightSql: $flightSql")
+ stmt.setSqlQuery(flightSql)
+ stmt.executeQuery()
+ }
+
+ private lazy val arrowReader: ArrowReader = queryResult.getReader
+
+ override def hasNext: Boolean = {
+ if (!eos.get && (rowBatch == null || !rowBatch.hasNext)) {
+ eos.set(!arrowReader.loadNextBatch())
+ if (!eos.get) {
+ rowBatch = new RowBatch(arrowReader, schema)
+ }
+ }
+ !eos.get
+ }
+
+ /**
+ * get next value.
+ *
+ * @return next value
+ */
+ override def next: AnyRef = {
+ if (!hasNext) {
+ logError(SHOULD_NOT_HAPPEN_MESSAGE)
+ throw new ShouldNeverHappenException
+ }
+ rowBatch.next
+ }
+
+ override def close(): Unit = {
+ if (rowBatch != null) {
+ rowBatch.close()
+ }
+ if (arrowReader != null) {
+ arrowReader.close()
+ }
+ if (queryResult != null) {
+ queryResult.close()
+ }
+ if (stmt != null) {
+ stmt.close()
+ }
+ if (conn != null) {
+ conn.close()
+ }
+ }
+
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala
index 0ff8bbd..768611c 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala
@@ -18,29 +18,32 @@
package org.apache.doris.spark.rdd
import scala.reflect.ClassTag
-
import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_VALUE_READER_CLASS
-import org.apache.doris.spark.cfg.Settings
+import org.apache.doris.spark.cfg.{ConfigurationOptions, Settings}
import org.apache.doris.spark.rest.PartitionDefinition
-
import org.apache.spark.{Partition, SparkContext, TaskContext}
private[spark] class ScalaDorisRDD[T: ClassTag](
- sc: SparkContext,
- params: Map[String, String] = Map.empty)
- extends AbstractDorisRDD[T](sc, params) {
+ sc: SparkContext,
+                                            params: Map[String, String] = Map.empty)
+ extends AbstractDorisRDD[T](sc, params) {
  override def compute(split: Partition, context: TaskContext): ScalaDorisRDDIterator[T] = {
    new ScalaDorisRDDIterator(context, split.asInstanceOf[DorisPartition].dorisPartition)
}
}
private[spark] class ScalaDorisRDDIterator[T](
- context: TaskContext,
- partition: PartitionDefinition)
- extends AbstractDorisRDDIterator[T](context, partition) {
+ context: TaskContext,
+ partition: PartitionDefinition)
+ extends AbstractDorisRDDIterator[T](context, partition) {
override def initReader(settings: Settings): Unit = {
-    settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaValueReader].getName)
+ settings.getProperty(ConfigurationOptions.DORIS_READ_MODE,
+ ConfigurationOptions.DORIS_READ_MODE_DEFAULT).toUpperCase match {
+      case "THRIFT" => settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaValueReader].getName)
+      case "ARROW" => settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaADBCValueReader].getName)
+      case mode: String => throw new IllegalArgumentException(s"Unsupported read mode: $mode")
+ }
}
override def createValue(value: Object): T = {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
index f9124a6..16707b8 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
@@ -45,7 +45,7 @@ import scala.util.control.Breaks
* @param partition Doris RDD partition
* @param settings request configuration
*/
-class ScalaValueReader(partition: PartitionDefinition, settings: Settings) extends Logging {
+class ScalaValueReader(partition: PartitionDefinition, settings: Settings) extends AbstractValueReader with Logging {
  private[this] lazy val client = new BackendClient(new Routing(partition.getBeAddress), settings)
@@ -53,8 +53,6 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten
private[this] val eos: AtomicBoolean = new AtomicBoolean(false)
- protected var rowBatch: RowBatch = _
-
// flag indicate if support deserialize Arrow to RowBatch asynchronously
private[this] lazy val deserializeArrowToRowBatchAsync: Boolean = Try {
    settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC, DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT.toString).toBoolean
@@ -173,7 +171,7 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten
* read data and cached in rowBatch.
* @return true if hax next value
*/
- def hasNext: Boolean = {
+ override def hasNext: Boolean = {
var hasNext = false
if (deserializeArrowToRowBatchAsync && asyncThreadStarted) {
// support deserialize Arrow to RowBatch asynchronously
@@ -219,7 +217,7 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten
* get next value.
* @return next value
*/
- def next: AnyRef = {
+ override def next: AnyRef = {
if (!hasNext) {
logError(SHOULD_NOT_HAPPEN_MESSAGE)
throw new ShouldNeverHappenException
@@ -227,7 +225,7 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten
rowBatch.next
}
- def close(): Unit = {
+ override def close(): Unit = {
val closeParams = new TScanCloseParams
closeParams.setContextId(contextId)
lockClient(_.closeScanner(closeParams))
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowADBCValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowADBCValueReader.scala
new file mode 100644
index 0000000..a658cdc
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowADBCValueReader.scala
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.sql
+
+import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_READ_FIELD
+import org.apache.doris.spark.cfg.Settings
+import org.apache.doris.spark.exception.ShouldNeverHappenException
+import org.apache.doris.spark.rdd.ScalaADBCValueReader
+import org.apache.doris.spark.rest.PartitionDefinition
+import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.collection.JavaConverters._
+
+class ScalaDorisRowADBCValueReader(partition: PartitionDefinition, settings: Settings)
+ extends ScalaADBCValueReader(partition, settings) {
+
+  private val logger: Logger = LoggerFactory.getLogger(classOf[ScalaDorisRowADBCValueReader].getName)
+
+ val rowOrder: Seq[String] = settings.getProperty(DORIS_READ_FIELD).split(",")
+
+ override def next: AnyRef = {
+ if (!hasNext) {
+ logger.error(SHOULD_NOT_HAPPEN_MESSAGE)
+ throw new ShouldNeverHappenException
+ }
+ val row: ScalaDorisRow = new ScalaDorisRow(rowOrder)
+ rowBatch.next.asScala.zipWithIndex.foreach{
+ case (s, index) if index < row.values.size => row.values.update(index, s)
+ case _ => // nothing
+ }
+ row
+ }
+
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
index b31a54d..6c3bd35 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala
@@ -18,10 +18,9 @@
package org.apache.doris.spark.sql
import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_VALUE_READER_CLASS
-import org.apache.doris.spark.cfg.Settings
+import org.apache.doris.spark.cfg.{ConfigurationOptions, Settings}
import org.apache.doris.spark.rdd.{AbstractDorisRDD, AbstractDorisRDDIterator, DorisPartition}
import org.apache.doris.spark.rest.PartitionDefinition
-
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
@@ -43,8 +42,13 @@ private[spark] class ScalaDorisRowRDDIterator(
struct: StructType)
extends AbstractDorisRDDIterator[Row](context, partition) {
- override def initReader(settings: Settings) = {
-    settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaDorisRowValueReader].getName)
+ override def initReader(settings: Settings): Unit = {
+ settings.getProperty(ConfigurationOptions.DORIS_READ_MODE,
+ ConfigurationOptions.DORIS_READ_MODE_DEFAULT).toUpperCase match {
+      case "THRIFT" => settings.setProperty (DORIS_VALUE_READER_CLASS, classOf[ScalaDorisRowValueReader].getName)
+      case "ARROW" => settings.setProperty (DORIS_VALUE_READER_CLASS, classOf[ScalaDorisRowADBCValueReader].getName)
+      case mode: String => throw new IllegalArgumentException(s"Unsupported read mode: $mode")
+ }
}
override def createValue(value: Object): Row = {
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
index 0400b04..2404584 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
@@ -20,6 +20,7 @@ package org.apache.doris.spark.sql
import org.apache.commons.lang3.StringUtils
import org.apache.doris.spark.cfg.ConfigurationOptions
import org.apache.doris.spark.exception.DorisException
+import org.apache.doris.spark.rest.PartitionDefinition
import org.apache.spark.sql.jdbc.JdbcDialect
import org.apache.spark.sql.sources._
import org.slf4j.Logger
@@ -28,6 +29,7 @@ import java.sql.{Date, Timestamp}
import java.time.{Duration, LocalDate}
import java.util.concurrent.locks.LockSupport
import scala.annotation.tailrec
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}
@@ -201,4 +203,32 @@ private[spark] object Utils {
case Failure(exception) => Failure(exception)
}
}
+
+  def generateQueryStatement(readColumns: Array[String], bitmapColumns: Array[String], hllColumns: Array[String],
+                             tableName: String, queryFilter: String, partitionOpt: Option[PartitionDefinition] = None): String = {
+
+ val columns = {
+ val finalReadColumns = readColumns.clone()
+      if (finalReadColumns(0) != "*" && bitmapColumns.nonEmpty && hllColumns.nonEmpty) {
+ for (i <- finalReadColumns.indices) {
+ finalReadColumns(i)
+ val readFieldName = finalReadColumns(i).replaceAll("`", "")
+          if (bitmapColumns.contains(readFieldName) || hllColumns.contains(readFieldName)) {
+            finalReadColumns(i) = "'READ UNSUPPORTED' AS " + finalReadColumns(i)
+ }
+ }
+ }
+ finalReadColumns.mkString(",")
+ }
+
+ val tabletClause = partitionOpt match {
+      case Some(partition) => s"TABLET(${partition.getTabletIds.asScala.mkString(",")})"
+ case None => ""
+ }
+ val whereClause = if (queryFilter.isEmpty) "" else s"WHERE $queryFilter"
+
+ s"SELECT $columns FROM $tableName $tabletClause $whereClause".trim
+
+ }
+
}
diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala
index 7e7919a..e9db609 100644
--- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala
+++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala
@@ -17,14 +17,17 @@
package org.apache.doris.spark.sql
-import org.apache.doris.spark.cfg.ConfigurationOptions
+import org.apache.doris.spark.cfg.{ConfigurationOptions, PropertiesSettings, Settings}
import org.apache.doris.spark.exception.DorisException
+import org.apache.doris.spark.rest.PartitionDefinition
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
import org.hamcrest.core.StringStartsWith.startsWith
import org.junit._
import org.slf4j.LoggerFactory
+import scala.collection.JavaConverters._
+
class TestUtils extends ExpectedExceptionTest {
private lazy val logger = LoggerFactory.getLogger(classOf[TestUtils])
@@ -132,4 +135,26 @@ class TestUtils extends ExpectedExceptionTest {
    thrown.expectMessage(startsWith(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_USER} cannot use in Doris Datasource,"))
Utils.params(parameters6, logger)
}
+
+ @Test
+ def testGenerateQueryStatement(): Unit = {
+
+ val readColumns = Array[String]("*")
+
+    val partition = new PartitionDefinition("db", "tbl1", new PropertiesSettings(), "127.0.0.1:8060", Set[java.lang.Long](1L).asJava, "")
+ Assert.assertEquals("SELECT * FROM `db`.`tbl1` TABLET(1)",
+      Utils.generateQueryStatement(readColumns, Array[String](), Array[String](), "`db`.`tbl1`", "", Some(partition)))
+
+ val readColumns1 = Array[String]("`c1`","`c2`","`c3`")
+
+ val bitmapColumns = Array[String]("c2")
+ val hllColumns = Array[String]("c3")
+
+ val where = "c1 = 10"
+
+    Assert.assertEquals("SELECT `c1`,'READ UNSUPPORTED' AS `c2`,'READ UNSUPPORTED' AS `c3` FROM `db`.`tbl1` WHERE c1 = 10",
+      Utils.generateQueryStatement(readColumns1, bitmapColumns, hllColumns, "`db`.`tbl1`", where))
+
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]