wangshuo128 opened a new pull request #25902: [SPARK-29213][SQL] Make it consistent when get notnull output and generate…
URL: https://github.com/apache/spark/pull/25902

### What changes were proposed in this pull request?

Currently, `FilterExec` computes output nullability and generates null checks in two different ways. As a result, a nullable attribute can mistakenly be treated as not nullable.

In `FilterExec.output`, an attribute is marked as non-nullable if its `exprId` is found in `notNullAttributes`:
```
a.nullable && notNullAttributes.contains(a.exprId)
```
But in `FilterExec.doConsume`, whether a null check is generated for an attribute is decided by whether a semantically equal `IsNotNull` predicate exists:
```
val nullChecks = c.references.map { r =>
  val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
  if (idx != -1 && !generatedIsNotNullChecks(idx)) {
    generatedIsNotNullChecks(idx) = true
    // Use the child's output. The nullability is what the child produced.
    genPredicate(notNullPreds(idx), input, child.output)
  } else {
    ""
  }
}.mkString("\n").trim
```
An NPE occurs when running the SQL below:
```
sql("create table table1(x string)")
sql("create table table2(x bigint)")
sql("create table table3(x string)")
sql("insert into table2 select null as x")
sql(
  """
    |select t1.x
    |from (
    |  select x from table1) t1
    |left join (
    |  select x from (
    |    select x from table2
    |    union all
    |    select substr(x,5) x from table3
    |  ) a
    |  where length(x)>0
    |) t3
    |on t1.x=t3.x
  """.stripMargin).collect()
```
The NPE stack trace:
```
java.lang.NullPointerException
  at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(generated.java:40)
  at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
  at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:726)
  at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
  at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:135)
  at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:94)
  at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
  at org.apache.spark.scheduler.Task.run(Task.scala:127)
  at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:449)
  at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
  at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:452)
  at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
  at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
  at java.lang.Thread.run(Thread.java:748)
```
The generated code:
```
== Subtree 4 / 5 ==
*(2) Project [cast(x#7L as string) AS x#9]
+- *(2) Filter ((length(cast(x#7L as string)) > 0) AND isnotnull(cast(x#7L as string)))
   +- Scan hive default.table2 [x#7L], HiveTableRelation `default`.`table2`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#7L]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     inputadapter_input_0 = inputs[0];
/* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 022 */
/* 023 */   }
/* 024 */
/* 025 */   protected void processNext() throws java.io.IOException {
/* 026 */     while ( inputadapter_input_0.hasNext()) {
/* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 028 */
/* 029 */       do {
/* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 031 */         long inputadapter_value_0 = inputadapter_isNull_0 ?
/* 032 */         -1L : (inputadapter_row_0.getLong(0));
/* 033 */
/* 034 */         boolean filter_isNull_2 = inputadapter_isNull_0;
/* 035 */         UTF8String filter_value_2 = null;
/* 036 */         if (!inputadapter_isNull_0) {
/* 037 */           filter_value_2 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 038 */         }
/* 039 */         int filter_value_1 = -1;
/* 040 */         filter_value_1 = (filter_value_2).numChars();
/* 041 */
/* 042 */         boolean filter_value_0 = false;
/* 043 */         filter_value_0 = filter_value_1 > 0;
/* 044 */         if (!filter_value_0) continue;
/* 045 */
/* 046 */         boolean filter_isNull_6 = inputadapter_isNull_0;
/* 047 */         UTF8String filter_value_6 = null;
/* 048 */         if (!inputadapter_isNull_0) {
/* 049 */           filter_value_6 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 050 */         }
/* 051 */         if (!(!filter_isNull_6)) continue;
/* 052 */
/* 053 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 054 */
/* 055 */         boolean project_isNull_0 = false;
/* 056 */         UTF8String project_value_0 = null;
/* 057 */         if (!false) {
/* 058 */           project_value_0 = UTF8String.fromString(String.valueOf(inputadapter_value_0));
/* 059 */         }
/* 060 */         filter_mutableStateArray_0[1].reset();
/* 061 */
/* 062 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
/* 063 */
/* 064 */         if (project_isNull_0) {
/* 065 */           filter_mutableStateArray_0[1].setNullAt(0);
/* 066 */         } else {
/* 067 */           filter_mutableStateArray_0[1].write(0, project_value_0);
/* 068 */         }
/* 069 */         append((filter_mutableStateArray_0[1].getRow()));
/* 070 */
/* 071 */       } while(false);
/* 072 */       if (shouldStop()) return;
/* 073 */     }
/* 074 */   }
/* 075 */
/* 076 */ }
```
This PR proposes to use semantic comparison for nullable attributes in both `FilterExec.output` and `FilterExec.doConsume`.
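For illustration, here is a minimal sketch of that idea, not the actual diff in this PR: `output` would decide nullability with the same `semanticEquals` test against the `IsNotNull` predicates that `doConsume` uses when emitting null checks, so the two code paths can no longer disagree. `notNullPreds` and `child` are assumed to be the existing `FilterExec` members.
```
// Sketch only: align FilterExec.output with doConsume's semantic comparison.
override def output: Seq[Attribute] = {
  child.output.map { a =>
    // Mark an attribute non-nullable only if some IsNotNull predicate's child is
    // semantically equal to it -- the same test doConsume uses to emit null checks.
    val hasSemanticNotNullCheck = notNullPreds.exists {
      case IsNotNull(c) => c.semanticEquals(a)
      case _ => false
    }
    if (a.nullable && hasSemanticNotNullCheck) a.withNullability(false) else a
  }
}
```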
With this PR, the generated code is:
```
== Subtree 2 / 5 ==
*(3) Project [substring(x#8, 5, 2147483647) AS x#5]
+- *(3) Filter ((length(substring(x#8, 5, 2147483647)) > 0) AND isnotnull(substring(x#8, 5, 2147483647)))
   +- Scan hive default.table3 [x#8], HiveTableRelation `default`.`table3`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [x#8]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage3(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=3
/* 006 */ final class GeneratedIteratorForCodegenStage3 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private scala.collection.Iterator inputadapter_input_0;
/* 010 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] filter_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];
/* 011 */
/* 012 */   public GeneratedIteratorForCodegenStage3(Object[] references) {
/* 013 */     this.references = references;
/* 014 */   }
/* 015 */
/* 016 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 017 */     partitionIndex = index;
/* 018 */     this.inputs = inputs;
/* 019 */     inputadapter_input_0 = inputs[0];
/* 020 */     filter_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 021 */     filter_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 32);
/* 022 */
/* 023 */   }
/* 024 */
/* 025 */   protected void processNext() throws java.io.IOException {
/* 026 */     while ( inputadapter_input_0.hasNext()) {
/* 027 */       InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();
/* 028 */
/* 029 */       do {
/* 030 */         boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
/* 031 */         UTF8String inputadapter_value_0 = inputadapter_isNull_0 ?
/* 032 */         null : (inputadapter_row_0.getUTF8String(0));
/* 033 */
/* 034 */         boolean filter_isNull_0 = true;
/* 035 */         boolean filter_value_0 = false;
/* 036 */         boolean filter_isNull_2 = true;
/* 037 */         UTF8String filter_value_2 = null;
/* 038 */
/* 039 */         if (!inputadapter_isNull_0) {
/* 040 */           filter_isNull_2 = false; // resultCode could change nullability.
/* 041 */           filter_value_2 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 042 */
/* 043 */         }
/* 044 */         boolean filter_isNull_1 = filter_isNull_2;
/* 045 */         int filter_value_1 = -1;
/* 046 */
/* 047 */         if (!filter_isNull_2) {
/* 048 */           filter_value_1 = (filter_value_2).numChars();
/* 049 */         }
/* 050 */         if (!filter_isNull_1) {
/* 051 */           filter_isNull_0 = false; // resultCode could change nullability.
/* 052 */           filter_value_0 = filter_value_1 > 0;
/* 053 */
/* 054 */         }
/* 055 */         if (filter_isNull_0 || !filter_value_0) continue;
/* 056 */         boolean filter_isNull_8 = true;
/* 057 */         UTF8String filter_value_8 = null;
/* 058 */
/* 059 */         if (!inputadapter_isNull_0) {
/* 060 */           filter_isNull_8 = false; // resultCode could change nullability.
/* 061 */           filter_value_8 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 062 */
/* 063 */         }
/* 064 */         if (!(!filter_isNull_8)) continue;
/* 065 */
/* 066 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 067 */
/* 068 */         boolean project_isNull_0 = true;
/* 069 */         UTF8String project_value_0 = null;
/* 070 */
/* 071 */         if (!inputadapter_isNull_0) {
/* 072 */           project_isNull_0 = false; // resultCode could change nullability.
/* 073 */           project_value_0 = inputadapter_value_0.substringSQL(5, 2147483647);
/* 074 */
/* 075 */         }
/* 076 */         filter_mutableStateArray_0[1].reset();
/* 077 */
/* 078 */         filter_mutableStateArray_0[1].zeroOutNullBytes();
/* 079 */
/* 080 */         if (project_isNull_0) {
/* 081 */           filter_mutableStateArray_0[1].setNullAt(0);
/* 082 */         } else {
/* 083 */           filter_mutableStateArray_0[1].write(0, project_value_0);
/* 084 */         }
/* 085 */         append((filter_mutableStateArray_0[1].getRow()));
/* 086 */
/* 087 */       } while(false);
/* 088 */       if (shouldStop()) return;
/* 089 */     }
/* 090 */   }
/* 091 */
/* 092 */ }
```

### How was this patch tested?

A new unit test.
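For reference, a regression test could look roughly like the sketch below; this is not necessarily the test added by the PR. It assumes a Hive-enabled test suite with the usual `withTable` and `checkAnswer` helpers from Spark's SQL test framework, and it reproduces the NPE scenario from the description.
```
test("SPARK-29213: FilterExec output nullability should match generated null checks") {
  withTable("table1", "table2", "table3") {
    sql("create table table1(x string)")
    sql("create table table2(x bigint)")
    sql("create table table3(x string)")
    sql("insert into table2 select null as x")
    val df = sql(
      """
        |select t1.x
        |from (select x from table1) t1
        |left join (
        |  select x from (
        |    select x from table2
        |    union all
        |    select substr(x, 5) x from table3
        |  ) a
        |  where length(x) > 0
        |) t3
        |on t1.x = t3.x
      """.stripMargin)
    // table1 is empty, so the left join produces no rows; before the fix this
    // query throws java.lang.NullPointerException in the generated code instead.
    checkAnswer(df, Nil)
  }
}
```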
