wangshuo128 opened a new pull request #25902: [SPARK-29213][SQL] Make it consistent when get notnull output and generate…
URL: https://github.com/apache/spark/pull/25902
 
 
   
   ### What changes were proposed in this pull request?
Currently, `FilterExec` uses different logic for computing its not-null output attributes and for generating null checks. As a result, a nullable attribute can mistakenly be treated as not nullable.
   
In `FilterExec.output`, an attribute is marked as not nullable by looking up its `exprId` in `notNullAttributes`:
   ```
   a.nullable && notNullAttributes.contains(a.exprId)
   ```
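   
   For context, this condition lives in `FilterExec.output`, which looks roughly like the following (paraphrased from the Spark source; `notNullAttributes` collects the `exprId`s referenced by the `IsNotNull` predicates of the filter condition):
   ```
   override def output: Seq[Attribute] = {
     child.output.map { a =>
       // Mark the attribute non-nullable when its exprId appears in the set
       // collected from the IsNotNull predicates.
       if (a.nullable && notNullAttributes.contains(a.exprId)) {
         a.withNullability(false)
       } else {
         a
       }
     }
   }
   ```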
   But in `FilterExec.doConsume`, whether a null check is generated for an attribute is decided by whether some not-null predicate is *semantically equal* to it:
   ```
   val nullChecks = c.references.map { r =>
     val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r) }
     if (idx != -1 && !generatedIsNotNullChecks(idx)) {
       generatedIsNotNullChecks(idx) = true
       // Use the child's output. The nullability is what the child produced.
       genPredicate(notNullPreds(idx), input, child.output)
     } else {
       ""
     }
   }.mkString("\n").trim
   ```
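   
   These two tests disagree when the not-null predicate wraps an expression rather than a bare attribute, e.g. `isnotnull(cast(x#7L as string))` in the plan below. Here is a minimal standalone sketch of the mismatch against the Catalyst expression API (an illustration only, not code from this PR):
   ```
   import org.apache.spark.sql.catalyst.expressions._
   import org.apache.spark.sql.types._

   val x = AttributeReference("x", LongType, nullable = true)()  // stands in for x#7L
   val notNullPreds: Seq[Expression] = Seq(IsNotNull(Cast(x, StringType)))

   // output-style check: isnotnull(cast(x)) still *references* x, so x's exprId
   // lands in notNullAttributes and x is reported as not nullable.
   val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
   println(x.nullable && notNullAttributes.contains(x.exprId))  // true

   // doConsume-style check: the reference x is not semantically equal to
   // cast(x as string), so no null check is generated for x.
   println(notNullPreds.exists(_.asInstanceOf[IsNotNull].child.semanticEquals(x)))  // false
   ```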
   An NPE occurs when running the SQL below:
   ```
   sql("create table table1(x string)")
   sql("create table table2(x bigint)")
   sql("create table table3(x string)")
   sql("insert into table2 select null as x")
   sql(
     """
       |select t1.x
       |from (
       |    select x from table1) t1
       |left join (
       |    select x from (
       |        select x from table2
       |        union all
       |        select substr(x,5) x from table3
       |    ) a
       |    where length(x)>0
       |) t3
       |on t1.x=t3.x
     """.stripMargin).collect()
   ```
   The resulting stack trace:
   ```
   java.lang.NullPointerException
       at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40)
       at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
       at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726)
       at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
       at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135)
       at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
       at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94)
       at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
       at org.apache.spark.scheduler.Task.run(Task.scala:127)
       at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449)
       at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
       at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452)
       at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
       at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
       at java.lang.Thread.run(Thread.java:748)
   ```
   The generated code: note line 040, where `numChars()` is called on `filter_value_2` with no null guard, while the `isnotnull` check only runs afterwards at line 051. When the input is null, line 040 throws the NPE.
   ```
   == Subtree 4 / 5 ==
   *(2) Project [cast(x#7L as string) AS x#9]
   +- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string)))
      +- Scan hive default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L]
   
   Generated code:
   /* 001 */ public Object generate(Object[] references) {
   /* 002 */   return new GeneratedIteratorForCodegenStage2(references);
   /* 003 */ }
   /* 004 */
   /* 005 */ // codegenStageId=2
   /* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
   /* 007 */   private Object[] references;
   /* 008 */   private scala.collection.Iterator[] inputs;
   /* 009 */   private scala.collection.Iterator inputadapter_input_0;
   /* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
   /* 011 */
   /* 012 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
   /* 013 */     this.references = references;
   /* 014 */   }
   /* 015 */
   /* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
   /* 017 */     partitionIndex = index;
   /* 018 */     this.inputs = inputs;
   /* 019 */     inputadapter_input_0 = inputs[0];
   /* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
   /* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
   /* 022 */
   /* 023 */   }
   /* 024 */
   /* 025 */   protected void processNext() throws java.io.IOException {
   /* 026 */     while ( inputadapter_input_0.hasNext()) {
   /* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
   /* 028 */
   /* 029 */       do {
   /* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
   /* 031 */         long inputadapter_value_0 = inputadapter_isNull_0 ?
   /* 032 */         -1L : (inputadapter_row_0.getLong(0));
   /* 033 */
   /* 034 */         boolean filter_isNull_2 = inputadapter_isNull_0;
   /* 035 */         UTF8String filter_value_2 = null;
   /* 036 */         if (!inputadapter_isNull_0) {
   /* 037 */           filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
   /* 038 */         }
   /* 039 */         int filter_value_1 = -1;
   /* 040 */         filter_value_1 = (filter_value_2).numChars();
   /* 041 */
   /* 042 */         boolean filter_value_0 = false;
   /* 043 */         filter_value_0 = filter_value_1 > 0;
   /* 044 */         if (!filter_value_0) continue;
   /* 045 */
   /* 046 */         boolean filter_isNull_6 = inputadapter_isNull_0;
   /* 047 */         UTF8String filter_value_6 = null;
   /* 048 */         if (!inputadapter_isNull_0) {
   /* 049 */           filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
   /* 050 */         }
   /* 051 */         if (!(!filter_isNull_6)) continue;
   /* 052 */
   /* 053 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
   /* 054 */
   /* 055 */         boolean project_isNull_0 = false;
   /* 056 */         UTF8String project_value_0 = null;
   /* 057 */         if (!false) {
   /* 058 */           project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
   /* 059 */         }
   /* 060 */         filter_mutableStateArray_0[1].reset();
   /* 061 */
   /* 062 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
   /* 063 */
   /* 064 */         if (project_isNull_0) {
   /* 065 */           filter_mutableStateArray_0[1].setNullAt(0);
   /* 066 */         } else {
   /* 067 */           filter_mutableStateArray_0[1].write(0, project_value_0);
   /* 068 */         }
   /* 069 */         append((filter_mutableStateArray_0[1].getRow()));
   /* 070 */
   /* 071 */       } while(false);
   /* 072 */       if (shouldStop()) return;
   /* 073 */     }
   /* 074 */   }
   /* 075 */
   /* 076 */ }
   
   ```
   
   This PR proposes to use semantic comparison for nullable attributes in both `FilterExec.output` and `FilterExec.doConsume`, so the two stay consistent.
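   
   A minimal sketch of that direction (illustrative, not the exact diff of this PR): `output` decides nullability with the same `semanticEquals` test that `doConsume` uses, so the two paths can no longer disagree.
   ```
   override def output: Seq[Attribute] = {
     child.output.map { a =>
       // Only strip nullability when a null check for this exact attribute is
       // guaranteed to be generated in doConsume.
       if (a.nullable &&
           notNullPreds.exists(_.asInstanceOf[IsNotNull].child.semanticEquals(a))) {
         a.withNullability(false)
       } else {
         a
       }
     }
   }
   ```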
   
   With this PR, the generated code is shown below. The length computation (line 048) is now guarded by `filter_isNull_2`, so `numChars()` is never called on a null value, and null rows are filtered out by the check at line 064.
   ```
   == Subtree 2 / 5 ==
   *(3) Project [substring(x#8, 5, 2147483647) AS x#5]
   +- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647)))
      +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8]
   
   Generated code:
   /* 001 */ public Object generate(Object[] references) {
   /* 002 */   return new GeneratedIteratorForCodegenStage3(references);
   /* 003 */ }
   /* 004 */
   /* 005 */ // codegenStageId=3
   /* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator {
   /* 007 */   private Object[] references;
   /* 008 */   private scala.collection.Iterator[] inputs;
   /* 009 */   private scala.collection.Iterator inputadapter_input_0;
   /* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
   /* 011 */
   /* 012 */   public GeneratedIteratorForCodegenStage3(Object[] references) {
   /* 013 */     this.references = references;
   /* 014 */   }
   /* 015 */
   /* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
   /* 017 */     partitionIndex = index;
   /* 018 */     this.inputs = inputs;
   /* 019 */     inputadapter_input_0 = inputs[0];
   /* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
   /* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
   /* 022 */
   /* 023 */   }
   /* 024 */
   /* 025 */   protected void processNext() throws java.io.IOException {
   /* 026 */     while ( inputadapter_input_0.hasNext()) {
   /* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
   /* 028 */
   /* 029 */       do {
   /* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
   /* 031 */         UTF8String inputadapter_value_0 = inputadapter_isNull_0 ?
   /* 032 */         null : (inputadapter_row_0.getUTF8String(0));
   /* 033 */
   /* 034 */         boolean filter_isNull_0 = true;
   /* 035 */         boolean filter_value_0 = false;
   /* 036 */         boolean filter_isNull_2 = true;
   /* 037 */         UTF8String filter_value_2 = null;
   /* 038 */
   /* 039 */         if (!inputadapter_isNull_0) {
   /* 040 */           filter_isNull_2 = false; // resultCode could change nullability.
   /* 041 */           filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647);
   /* 042 */
   /* 043 */         }
   /* 044 */         boolean filter_isNull_1 = filter_isNull_2;
   /* 045 */         int filter_value_1 = -1;
   /* 046 */
   /* 047 */         if (!filter_isNull_2) {
   /* 048 */           filter_value_1 = (filter_value_2).numChars();
   /* 049 */         }
   /* 050 */         if (!filter_isNull_1) {
   /* 051 */           filter_isNull_0 = false; // resultCode could change nullability.
   /* 052 */           filter_value_0 = filter_value_1 > 0;
   /* 053 */
   /* 054 */         }
   /* 055 */         if (filter_isNull_0 || !filter_value_0) continue;
   /* 056 */         boolean filter_isNull_8 = true;
   /* 057 */         UTF8String filter_value_8 = null;
   /* 058 */
   /* 059 */         if (!inputadapter_isNull_0) {
   /* 060 */           filter_isNull_8 = false; // resultCode could change nullability.
   /* 061 */           filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647);
   /* 062 */
   /* 063 */         }
   /* 064 */         if (!(!filter_isNull_8)) continue;
   /* 065 */
   /* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
   /* 067 */
   /* 068 */         boolean project_isNull_0 = true;
   /* 069 */         UTF8String project_value_0 = null;
   /* 070 */
   /* 071 */         if (!inputadapter_isNull_0) {
   /* 072 */           project_isNull_0 = false; // resultCode could change nullability.
   /* 073 */           project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647);
   /* 074 */
   /* 075 */         }
   /* 076 */         filter_mutableStateArray_0[1].reset();
   /* 077 */
   /* 078 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
   /* 079 */
   /* 080 */         if (project_isNull_0) {
   /* 081 */           filter_mutableStateArray_0[1].setNullAt(0);
   /* 082 */         } else {
   /* 083 */           filter_mutableStateArray_0[1].write(0, project_value_0);
   /* 084 */         }
   /* 085 */         append((filter_mutableStateArray_0[1].getRow()));
   /* 086 */
   /* 087 */       } while(false);
   /* 088 */       if (shouldStop()) return;
   /* 089 */     }
   /* 090 */   }
   /* 091 */
   /* 092 */ }
   ```
   
   ### How was this patch tested?
   A new unit test (see the sketch below).
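   
   Roughly along these lines: a sketch of the repro turned into a regression test (the test name and `withTable` harness are assumptions, not necessarily the PR's exact test):
   ```
   test("SPARK-29213: FilterExec nullability must stay consistent with null checks") {
     withTable("table1", "table2", "table3") {
       sql("create table table1(x string)")
       sql("create table table2(x bigint)")
       sql("create table table3(x string)")
       sql("insert into table2 select null as x")
       // Must not throw NullPointerException from the generated code.
       sql(
         """
           |select t1.x
           |from (select x from table1) t1
           |left join (
           |    select x from (
           |        select x from table2
           |        union all
           |        select substr(x,5) x from table3
           |    ) a
           |    where length(x)>0
           |) t3
           |on t1.x=t3.x
         """.stripMargin).collect()
     }
   }
   ```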
   
