vibhatha commented on code in PR #40056:
URL: https://github.com/apache/arrow/pull/40056#discussion_r1487303596
##########
docs/source/java/substrait.rst:
##########
@@ -148,297 +136,60 @@ This Java program:
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;
+ import org.apache.calcite.sql.parser.SqlParseException;
+
+ import java.nio.ByteBuffer;
+ import java.util.Base64;
+ import java.util.Optional;
public class ClientSubstraitExtendedExpressionsCookbook {
- public static void main(String[] args) throws Exception {
- // project and filter dataset using extended expression definition -
03 Expressions:
- // Expression 01 - CONCAT: N_NAME || ' - ' || N_COMMENT = col 1 || ' -
' || col 3
- // Expression 02 - ADD: N_REGIONKEY + 10 = col 1 + 10
- // Expression 03 - FILTER: N_NATIONKEY > 18 = col 3 > 18
+ public static void main(String[] args) throws SqlParseException {
projectAndFilterDataset();
}
- public static void projectAndFilterDataset() {
+ private static void projectAndFilterDataset() throws SqlParseException {
String uri = "file:///Users/data/tpch_parquet/nation.parquet";
- ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768)
- .columns(Optional.empty())
- .substraitFilter(getSubstraitExpressionFilter())
- .substraitProjection(getSubstraitExpressionProjection())
- .build();
- try (
- BufferAllocator allocator = new RootAllocator();
- DatasetFactory datasetFactory = new FileSystemDatasetFactory(
- allocator, NativeMemoryPool.getDefault(),
- FileFormat.PARQUET, uri);
- Dataset dataset = datasetFactory.finish();
- Scanner scanner = dataset.newScan(options);
- ArrowReader reader = scanner.scanBatches()
- ) {
+ ScanOptions options =
+ new ScanOptions.Builder(/*batchSize*/ 32768)
+ .columns(Optional.empty())
+ .substraitFilter(getByteBuffer(new String[]{"N_NATIONKEY >
18"}))
+ .substraitProjection(getByteBuffer(new String[]{"N_REGIONKEY +
10",
+ "N_NAME || CAST(' - ' as VARCHAR) || N_COMMENT"}))
+ .build();
+ try (BufferAllocator allocator = new RootAllocator();
+ DatasetFactory datasetFactory =
+ new FileSystemDatasetFactory(
+ allocator, NativeMemoryPool.getDefault(),
FileFormat.PARQUET, uri);
+ Dataset dataset = datasetFactory.finish();
+ Scanner scanner = dataset.newScan(options);
+ ArrowReader reader = scanner.scanBatches()) {
while (reader.loadNextBatch()) {
- System.out.println(
- reader.getVectorSchemaRoot().contentToTSVString());
+
System.out.println(reader.getVectorSchemaRoot().contentToTSVString());
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
- private static ByteBuffer getSubstraitExpressionProjection() {
- // Expression: N_REGIONKEY + 10 = col 3 + 10
- Expression.Builder selectionBuilderProjectOne =
Expression.newBuilder().
- setSelection(
- Expression.FieldReference.newBuilder().
- setDirectReference(
- Expression.ReferenceSegment.newBuilder().
- setStructField(
-
Expression.ReferenceSegment.StructField.newBuilder().setField(
- 2)
- )
- )
- );
- Expression.Builder literalBuilderProjectOne = Expression.newBuilder()
- .setLiteral(
- Expression.Literal.newBuilder().setI32(10)
- );
- io.substrait.proto.Type outputProjectOne =
TypeCreator.NULLABLE.I32.accept(
- new TypeProtoConverter(new ExtensionCollector()));
- Expression.Builder expressionBuilderProjectOne = Expression.
- newBuilder().
- setScalarFunction(
- Expression.
- ScalarFunction.
- newBuilder().
- setFunctionReference(0).
- setOutputType(outputProjectOne).
- addArguments(
- 0,
- FunctionArgument.newBuilder().setValue(
- selectionBuilderProjectOne)
- ).
- addArguments(
- 1,
- FunctionArgument.newBuilder().setValue(
- literalBuilderProjectOne)
- )
- );
- ExpressionReference.Builder expressionReferenceBuilderProjectOne =
ExpressionReference.newBuilder().
- setExpression(expressionBuilderProjectOne)
- .addOutputNames("ADD_TEN_TO_COLUMN_N_REGIONKEY");
-
- // Expression: name || name = N_NAME || "-" || N_COMMENT = col 1 ||
col 3
- Expression.Builder selectionBuilderProjectTwo =
Expression.newBuilder().
- setSelection(
- Expression.FieldReference.newBuilder().
- setDirectReference(
- Expression.ReferenceSegment.newBuilder().
- setStructField(
-
Expression.ReferenceSegment.StructField.newBuilder().setField(
- 1)
- )
- )
- );
- Expression.Builder selectionBuilderProjectTwoConcatLiteral =
Expression.newBuilder()
- .setLiteral(
- Expression.Literal.newBuilder().setString(" - ")
- );
- Expression.Builder selectionBuilderProjectOneToConcat =
Expression.newBuilder().
- setSelection(
- Expression.FieldReference.newBuilder().
- setDirectReference(
- Expression.ReferenceSegment.newBuilder().
- setStructField(
-
Expression.ReferenceSegment.StructField.newBuilder().setField(
- 3)
- )
- )
- );
- io.substrait.proto.Type outputProjectTwo =
TypeCreator.NULLABLE.STRING.accept(
- new TypeProtoConverter(new ExtensionCollector()));
- Expression.Builder expressionBuilderProjectTwo = Expression.
- newBuilder().
- setScalarFunction(
- Expression.
- ScalarFunction.
- newBuilder().
- setFunctionReference(1).
- setOutputType(outputProjectTwo).
- addArguments(
- 0,
- FunctionArgument.newBuilder().setValue(
- selectionBuilderProjectTwo)
- ).
- addArguments(
- 1,
- FunctionArgument.newBuilder().setValue(
- selectionBuilderProjectTwoConcatLiteral)
- ).
- addArguments(
- 2,
- FunctionArgument.newBuilder().setValue(
- selectionBuilderProjectOneToConcat)
- )
- );
- ExpressionReference.Builder expressionReferenceBuilderProjectTwo =
ExpressionReference.newBuilder().
- setExpression(expressionBuilderProjectTwo)
- .addOutputNames("CONCAT_COLUMNS_N_NAME_AND_N_COMMENT");
-
- List<String> columnNames = Arrays.asList("N_NATIONKEY", "N_NAME",
- "N_REGIONKEY", "N_COMMENT");
- List<Type> dataTypes = Arrays.asList(
- TypeCreator.NULLABLE.I32,
- TypeCreator.NULLABLE.STRING,
- TypeCreator.NULLABLE.I32,
- TypeCreator.NULLABLE.STRING
- );
- NamedStruct of = NamedStruct.of(
- columnNames,
- Type.Struct.builder().fields(dataTypes).nullable(false).build()
- );
- // Extensions URI
- HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
- extensionUris.put(
- "key-001",
- SimpleExtensionURI.newBuilder()
- .setExtensionUriAnchor(1)
- .setUri("/functions_arithmetic.yaml")
- .build()
- );
- // Extensions
- ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
- SimpleExtensionDeclaration extensionFunctionAdd =
SimpleExtensionDeclaration.newBuilder()
- .setExtensionFunction(
- SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
- .setFunctionAnchor(0)
- .setName("add:i32_i32")
- .setExtensionUriReference(1))
- .build();
- SimpleExtensionDeclaration extensionFunctionGreaterThan =
SimpleExtensionDeclaration.newBuilder()
- .setExtensionFunction(
- SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
- .setFunctionAnchor(1)
- .setName("concat:vchar")
- .setExtensionUriReference(2))
- .build();
- extensions.add(extensionFunctionAdd);
- extensions.add(extensionFunctionGreaterThan);
- // Extended Expression
- ExtendedExpression.Builder extendedExpressionBuilder =
- ExtendedExpression.newBuilder().
- addReferredExpr(0,
- expressionReferenceBuilderProjectOne).
- addReferredExpr(1,
- expressionReferenceBuilderProjectTwo).
- setBaseSchema(of.toProto(new TypeProtoConverter(
- new ExtensionCollector())));
- extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
- extendedExpressionBuilder.addAllExtensions(extensions);
- ExtendedExpression extendedExpression =
extendedExpressionBuilder.build();
- byte[] extendedExpressions = Base64.getDecoder().decode(
- Base64.getEncoder().encodeToString(
- extendedExpression.toByteArray()));
- ByteBuffer substraitExpressionProjection = ByteBuffer.allocateDirect(
- extendedExpressions.length);
- substraitExpressionProjection.put(extendedExpressions);
- return substraitExpressionProjection;
- }
-
- private static ByteBuffer getSubstraitExpressionFilter() {
- // Expression: Filter: N_NATIONKEY > 18 = col 1 > 18
- Expression.Builder selectionBuilderFilterOne = Expression.newBuilder().
- setSelection(
- Expression.FieldReference.newBuilder().
- setDirectReference(
- Expression.ReferenceSegment.newBuilder().
- setStructField(
-
Expression.ReferenceSegment.StructField.newBuilder().setField(
- 0)
- )
- )
- );
- Expression.Builder literalBuilderFilterOne = Expression.newBuilder()
- .setLiteral(
- Expression.Literal.newBuilder().setI32(18)
- );
- io.substrait.proto.Type outputFilterOne =
TypeCreator.NULLABLE.BOOLEAN.accept(
- new TypeProtoConverter(new ExtensionCollector()));
- Expression.Builder expressionBuilderFilterOne = Expression.
- newBuilder().
- setScalarFunction(
- Expression.
- ScalarFunction.
- newBuilder().
- setFunctionReference(1).
- setOutputType(outputFilterOne).
- addArguments(
- 0,
- FunctionArgument.newBuilder().setValue(
- selectionBuilderFilterOne)
- ).
- addArguments(
- 1,
- FunctionArgument.newBuilder().setValue(
- literalBuilderFilterOne)
- )
- );
- ExpressionReference.Builder expressionReferenceBuilderFilterOne =
ExpressionReference.newBuilder().
- setExpression(expressionBuilderFilterOne)
- .addOutputNames("COLUMN_N_NATIONKEY_GREATER_THAN_18");
-
- List<String> columnNames = Arrays.asList("N_NATIONKEY", "N_NAME",
- "N_REGIONKEY", "N_COMMENT");
- List<Type> dataTypes = Arrays.asList(
- TypeCreator.NULLABLE.I32,
- TypeCreator.NULLABLE.STRING,
- TypeCreator.NULLABLE.I32,
- TypeCreator.NULLABLE.STRING
- );
- NamedStruct of = NamedStruct.of(
- columnNames,
- Type.Struct.builder().fields(dataTypes).nullable(false).build()
- );
- // Extensions URI
- HashMap<String, SimpleExtensionURI> extensionUris = new HashMap<>();
- extensionUris.put(
- "key-001",
- SimpleExtensionURI.newBuilder()
- .setExtensionUriAnchor(1)
- .setUri("/functions_comparison.yaml")
- .build()
- );
- // Extensions
- ArrayList<SimpleExtensionDeclaration> extensions = new ArrayList<>();
- SimpleExtensionDeclaration extensionFunctionLowerThan =
SimpleExtensionDeclaration.newBuilder()
- .setExtensionFunction(
- SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
- .setFunctionAnchor(1)
- .setName("gt:any_any")
- .setExtensionUriReference(1))
- .build();
- extensions.add(extensionFunctionLowerThan);
- // Extended Expression
- ExtendedExpression.Builder extendedExpressionBuilder =
- ExtendedExpression.newBuilder().
- addReferredExpr(0,
- expressionReferenceBuilderFilterOne).
- setBaseSchema(of.toProto(new TypeProtoConverter(
- new ExtensionCollector())));
- extendedExpressionBuilder.addAllExtensionUris(extensionUris.values());
- extendedExpressionBuilder.addAllExtensions(extensions);
- ExtendedExpression extendedExpression =
extendedExpressionBuilder.build();
- byte[] extendedExpressions = Base64.getDecoder().decode(
- Base64.getEncoder().encodeToString(
- extendedExpression.toByteArray()));
- ByteBuffer substraitExpressionFilter = ByteBuffer.allocateDirect(
- extendedExpressions.length);
- substraitExpressionFilter.put(extendedExpressions);
- return substraitExpressionFilter;
+ private static ByteBuffer getByteBuffer(String[] sqlExpression) throws
SqlParseException {
+ String schema =
+ "CREATE TABLE NATION (N_NATIONKEY INT NOT NULL, N_NAME VARCHAR, "
+ + "N_REGIONKEY INT NOT NULL, N_COMMENT VARCHAR)";
+ SqlExpressionToSubstrait expressionToSubstrait = new
SqlExpressionToSubstrait();
+ ExtendedExpression expression =
+ expressionToSubstrait.convert(sqlExpression,
ImmutableList.of(schema));
+ byte[] expressionToByte =
+
Base64.getDecoder().decode(Base64.getEncoder().encodeToString(expression.toByteArray()));
+ ByteBuffer byteBuffer =
ByteBuffer.allocateDirect(expressionToByte.length);
+ byteBuffer.put(expressionToByte);
+ return byteBuffer;
}
}
.. code-block:: text
- ADD_TEN_TO_COLUMN_N_REGIONKEY CONCAT_COLUMNS_N_NAME_AND_N_COMMENT
+ column-1 column-2
Review Comment:
@davisusanibar I guess these column name changes come from the Substrait
side?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]