zstan commented on code in PR #3608:
URL: https://github.com/apache/ignite-3/pull/3608#discussion_r1576276053


##########
modules/sql-engine/src/main/java/org/apache/ignite/internal/sql/engine/exec/rel/HashJoinNode.java:
##########
@@ -0,0 +1,623 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.internal.sql.engine.exec.rel;
+
+import static 
org.apache.ignite.internal.sql.engine.util.TypeUtils.rowSchemaFromRelTypes;
+
+import it.unimi.dsi.fastutil.objects.Object2ObjectOpenHashMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.rel.core.JoinInfo;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.ignite.internal.sql.engine.exec.ExecutionContext;
+import org.apache.ignite.internal.sql.engine.exec.RowHandler;
+import org.apache.ignite.internal.sql.engine.exec.row.RowSchema;
+import org.jetbrains.annotations.Nullable;
+
+/** HashJoin implementor. */
+public abstract class HashJoinNode<RowT> extends 
AbstractRightMaterializedJoinNode<RowT> {
+    Map<Object, Object> hashStore = new Object2ObjectOpenHashMap<>();
+    protected final RowHandler<RowT> handler;
+
+    final Collection<Integer> leftJoinPositions;
+    private final Collection<Integer> rightJoinPositions;
+
+    final boolean touchResults;
+
+    Iterator<RowT> rightIt = Collections.emptyIterator();
+
+    /**
+     * Creates the node, caching the row handler and the equi-join key positions of both inputs.
+     *
+     * @param ctx Execution context; also the source of the row handler.
+     * @param joinInfo Join info supplying the left and right key column positions.
+     * @param touch Value stored in {@link #touchResults}; when set, buckets are created as {@code TouchedList}.
+     */
+    private HashJoinNode(ExecutionContext<RowT> ctx, JoinInfo joinInfo, boolean touch) {
+        super(ctx);
+
+        this.touchResults = touch;
+        this.handler = ctx.rowHandler();
+
+        this.rightJoinPositions = joinInfo.rightKeys.toIntegerList();
+        this.leftJoinPositions = joinInfo.leftKeys.toIntegerList();
+    }
+
+    /** Drops the materialized right rows and the in-flight match iterator, then delegates to the parent. */
+    @Override
+    protected void rewindInternal() {
+        hashStore.clear();
+        rightIt = Collections.emptyIterator();
+
+        super.rewindInternal();
+    }
+
+    /**
+     * Builds the hash join node implementing the requested join type.
+     *
+     * @param ctx Execution context.
+     * @param outputRowType Output row type; not used by the visible implementation.
+     * @param leftRowType Left input row type; used to build the left row factory for RIGHT and FULL joins.
+     * @param rightRowType Right input row type; used to build the right row factory for LEFT and FULL joins.
+     * @param joinType Join type.
+     * @param joinInfo Join info with the equi-join keys.
+     * @return Node for the given join type.
+     * @throws IllegalStateException If the join type has no hash join implementation.
+     */
+    public static <RowT> HashJoinNode<RowT> create(ExecutionContext<RowT> ctx, RelDataType outputRowType,
+            RelDataType leftRowType, RelDataType rightRowType, JoinRelType joinType, JoinInfo joinInfo) {
+        switch (joinType) {
+            case INNER:
+                return new InnerHashJoin<>(ctx, joinInfo);
+
+            case SEMI:
+                return new SemiHashJoin<>(ctx, joinInfo);
+
+            case ANTI:
+                return new AntiHashJoin<>(ctx, joinInfo);
+
+            case LEFT: {
+                RowSchema rightSchema = rowSchemaFromRelTypes(RelOptUtil.getFieldTypeList(rightRowType));
+                RowHandler.RowFactory<RowT> rightFactory = ctx.rowHandler().factory(rightSchema);
+
+                return new LeftHashJoin<>(ctx, rightFactory, joinInfo);
+            }
+
+            case RIGHT: {
+                RowSchema leftSchema = rowSchemaFromRelTypes(RelOptUtil.getFieldTypeList(leftRowType));
+                RowHandler.RowFactory<RowT> leftFactory = ctx.rowHandler().factory(leftSchema);
+
+                return new RightHashJoin<>(ctx, leftFactory, joinInfo);
+            }
+
+            case FULL: {
+                RowSchema leftSchema = rowSchemaFromRelTypes(RelOptUtil.getFieldTypeList(leftRowType));
+                RowSchema rightSchema = rowSchemaFromRelTypes(RelOptUtil.getFieldTypeList(rightRowType));
+                RowHandler.RowFactory<RowT> leftFactory = ctx.rowHandler().factory(leftSchema);
+                RowHandler.RowFactory<RowT> rightFactory = ctx.rowHandler().factory(rightSchema);
+
+                return new FullOuterHashJoin<>(ctx, leftFactory, rightFactory, joinInfo);
+            }
+
+            default:
+                throw new IllegalStateException("Join type \"" + joinType + "\" is not supported yet");
+        }
+    }
+
+    /** Inner join: pushes the concatenation of a left row with every right row stored under the same key. */
+    private static class InnerHashJoin<RowT> extends HashJoinNode<RowT> {
+        private InnerHashJoin(ExecutionContext<RowT> ctx, JoinInfo joinInfo) {
+            super(ctx, joinInfo, false);
+        }
+
+        /**
+         * Probes the hash store with buffered left rows while the downstream still has outstanding demand.
+         * Probing only starts once {@code waitingRight == NOT_WAITING}.
+         */
+        @Override
+        protected void join() throws Exception {
+            if (waitingRight == NOT_WAITING) {
+                inLoop = true;
+                try {
+                    while (requested > 0 && (left != null || !leftInBuf.isEmpty())) {
+                        // Matches of the previous left row are exhausted: take the next buffered left row.
+                        if (!rightIt.hasNext()) {
+                            left = leftInBuf.remove();
+                            rightIt = lookup(left, handler, hashStore, leftJoinPositions, touchResults).iterator();
+                        }
+
+                        while (rightIt.hasNext()) {
+                            checkState();
+
+                            RowT matched = rightIt.next();
+
+                            requested--;
+                            downstream().push(handler.concat(left, matched));
+
+                            // Demand satisfied; "left" stays set if it still has unread matches.
+                            if (requested == 0) {
+                                break;
+                            }
+                        }
+
+                        if (!rightIt.hasNext()) {
+                            left = null;
+                        }
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            getMoreOrEnd();
+        }
+    }
+
+    /**
+     * Left outer join: pushes every left/right match; a left row without matches is pushed once, concatenated
+     * with a row freshly created by the right row factory.
+     */
+    private static class LeftHashJoin<RowT> extends HashJoinNode<RowT> {
+        /** Factory of the right-side row appended to left rows that have no match. */
+        private final RowHandler.RowFactory<RowT> rightRowFactory;
+
+        private LeftHashJoin(
+                ExecutionContext<RowT> ctx,
+                RowHandler.RowFactory<RowT> rightRowFactory,
+                JoinInfo joinInfo
+        ) {
+            super(ctx, joinInfo, false);
+
+            this.rightRowFactory = rightRowFactory;
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        protected void join() throws Exception {
+            if (waitingRight == NOT_WAITING) {
+                inLoop = true;
+                try {
+                    while (requested > 0 && (left != null || !leftInBuf.isEmpty())) {
+                        checkState();
+
+                        // Matches of the previous left row are exhausted: take the next buffered left row.
+                        if (!rightIt.hasNext()) {
+                            left = leftInBuf.remove();
+
+                            Collection<RowT> matches = lookup(left, handler, hashStore, leftJoinPositions, touchResults);
+
+                            // No partner on the right side: the left row is still emitted exactly once.
+                            if (matches.isEmpty()) {
+                                requested--;
+                                downstream().push(handler.concat(left, rightRowFactory.create()));
+                            }
+
+                            rightIt = matches.iterator();
+                        }
+
+                        while (rightIt.hasNext()) {
+                            checkState();
+
+                            RowT matched = rightIt.next();
+
+                            requested--;
+                            downstream().push(handler.concat(left, matched));
+
+                            if (requested == 0) {
+                                break;
+                            }
+                        }
+
+                        if (!rightIt.hasNext()) {
+                            left = null;
+                        }
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            getMoreOrEnd();
+        }
+    }
+
+    private static class RightHashJoin<RowT> extends HashJoinNode<RowT> {
+        /** Left row factory. */
+        private final RowHandler.RowFactory<RowT> leftRowFactory;
+
+        private RightHashJoin(
+                ExecutionContext<RowT> ctx,
+                RowHandler.RowFactory<RowT> leftRowFactory,
+                JoinInfo joinInfo
+        ) {
+            super(ctx, joinInfo, true);
+
+            this.leftRowFactory = leftRowFactory;
+        }
+
+        @Override
+        protected void rewindInternal() {
+            HashJoinNode.resetTouched(hashStore);
+
+            super.rewindInternal();
+        }
+
+        @Override
+        protected void join() throws Exception {
+            if (waitingRight == NOT_WAITING) {
+                inLoop = true;
+                try {
+                    while (requested > 0 && (left != null || 
!leftInBuf.isEmpty())) {
+                        checkState();
+
+                        if (!rightIt.hasNext()) {
+                            left = leftInBuf.remove();
+
+                            Collection<RowT> rightRows = lookup(left, handler, 
hashStore, leftJoinPositions, touchResults);
+
+                            rightIt = rightRows.iterator();
+                        }
+
+                        if (rightIt.hasNext()) {
+                            while (rightIt.hasNext()) {
+                                checkState();
+
+                                RowT right = rightIt.next();
+
+                                --requested;
+
+                                RowT row = handler.concat(left, right);
+                                downstream().push(row);
+
+                                if (requested == 0) {
+                                    break;
+                                }
+                            }
+                        }
+
+                        if (!rightIt.hasNext()) {
+                            left = null;
+                        }
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            if (left == null && leftInBuf.isEmpty() && waitingLeft == 
NOT_WAITING && waitingRight == NOT_WAITING && requested > 0) {
+                inLoop = true;
+                try {
+                    if (!rightIt.hasNext()) {
+                        List<RowT> res = getUntouched(hashStore, null);
+                        rightIt = res.iterator();
+                    }
+
+                    RowT emptyLeft = leftRowFactory.create();
+
+                    while (rightIt.hasNext()) {
+                        checkState();
+                        RowT right = rightIt.next();
+                        RowT row = handler.concat(emptyLeft, right);
+                        --requested;
+
+                        downstream().push(row);
+
+                        if (requested == 0) {
+                            break;
+                        }
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            getMoreOrEnd();
+        }
+    }
+
+    /**
+     * Full outer join: pushes every left/right match, every left row without a match (concatenated with a row
+     * created by the right row factory) and, once no left rows remain, every right row whose bucket was never
+     * touched by a lookup (concatenated with a row created by the left row factory).
+     */
+    private static class FullOuterHashJoin<RowT> extends HashJoinNode<RowT> {
+        /** Left row factory. */
+        private final RowHandler.RowFactory<RowT> leftRowFactory;
+
+        /** Right row factory. */
+        private final RowHandler.RowFactory<RowT> rightRowFactory;
+
+        private FullOuterHashJoin(
+                ExecutionContext<RowT> ctx,
+                RowHandler.RowFactory<RowT> leftRowFactory,
+                RowHandler.RowFactory<RowT> rightRowFactory,
+                JoinInfo joinInfo
+        ) {
+            super(ctx, joinInfo, true);
+
+            this.leftRowFactory = leftRowFactory;
+            this.rightRowFactory = rightRowFactory;
+        }
+
+        /** Clears the touched flags of all buckets before the common rewind logic runs. */
+        @Override
+        protected void rewindInternal() {
+            HashJoinNode.resetTouched(hashStore);
+
+            super.rewindInternal();
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        protected void join() throws Exception {
+            // Phase 1: probe the hash store with buffered left rows.
+            if (waitingRight == NOT_WAITING) {
+                inLoop = true;
+                try {
+                    while (requested > 0 && (left != null || !leftInBuf.isEmpty())) {
+                        checkState();
+
+                        if (!rightIt.hasNext()) {
+                            left = leftInBuf.remove();
+
+                            Collection<RowT> rightRows = lookup(left, handler, hashStore, leftJoinPositions, touchResults);
+
+                            // No partner on the right side: the left row is still emitted exactly once.
+                            if (rightRows.isEmpty()) {
+                                requested--;
+                                downstream().push(handler.concat(left, rightRowFactory.create()));
+                            }
+
+                            rightIt = rightRows.iterator();
+                        }
+
+                        while (rightIt.hasNext()) {
+                            checkState();
+
+                            RowT right = rightIt.next();
+
+                            --requested;
+
+                            RowT row = handler.concat(left, right);
+                            downstream().push(row);
+
+                            if (requested == 0) {
+                                break;
+                            }
+                        }
+
+                        if (!rightIt.hasNext()) {
+                            left = null;
+                        }
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            // Phase 2: emit unmatched right rows. The guard must NOT require an exhausted rightIt: when the
+            // downstream demand runs out in the middle of the unmatched rows, rightIt still holds the remainder
+            // and this phase has to resume on the next call (same guard as in RightHashJoin). "left == null"
+            // already guarantees that no phase-1 match iterator is pending, because phase 1 resets "left" only
+            // after its iterator is drained.
+            if (left == null && leftInBuf.isEmpty() && waitingLeft == NOT_WAITING
+                    && waitingRight == NOT_WAITING && requested > 0) {
+                inLoop = true;
+                try {
+                    if (!rightIt.hasNext()) {
+                        List<RowT> res = getUntouched(hashStore, null);
+                        rightIt = res.iterator();
+                    }
+
+                    RowT emptyLeft = leftRowFactory.create();
+
+                    while (rightIt.hasNext()) {
+                        checkState();
+                        RowT right = rightIt.next();
+                        RowT row = handler.concat(emptyLeft, right);
+                        --requested;
+
+                        downstream().push(row);
+
+                        if (requested == 0) {
+                            break;
+                        }
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            getMoreOrEnd();
+        }
+    }
+
+    /** Semi join: pushes a left row (without any right columns) when at least one right row shares its key. */
+    private static class SemiHashJoin<RowT> extends HashJoinNode<RowT> {
+        private SemiHashJoin(ExecutionContext<RowT> ctx, JoinInfo joinInfo) {
+            super(ctx, joinInfo, false);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        protected void join() throws Exception {
+            if (waitingRight == NOT_WAITING) {
+                inLoop = true;
+                try {
+                    // Every iteration consumes exactly one buffered left row, so the buffer is the only thing
+                    // to test; the loop condition also stops the loop once the demand is satisfied.
+                    while (requested > 0 && !leftInBuf.isEmpty()) {
+                        checkState();
+
+                        left = leftInBuf.remove();
+
+                        Collection<RowT> rightRows = lookup(left, handler, hashStore, leftJoinPositions, touchResults);
+
+                        if (!rightRows.isEmpty()) {
+                            requested--;
+
+                            downstream().push(left);
+                        }
+
+                        // Always reset: the row is fully processed at this point. Leaving it set after the
+                        // demand ran out made the next call remove from a possibly empty buffer.
+                        left = null;
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            getMoreOrEnd();
+        }
+    }
+
+    /** Anti join: pushes a left row (without any right columns) when no right row shares its key. */
+    private static class AntiHashJoin<RowT> extends HashJoinNode<RowT> {
+        private AntiHashJoin(ExecutionContext<RowT> ctx, JoinInfo joinInfo) {
+            super(ctx, joinInfo, false);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        protected void join() throws Exception {
+            if (waitingRight == NOT_WAITING) {
+                inLoop = true;
+                try {
+                    // Every iteration consumes exactly one buffered left row, so the buffer is the only thing
+                    // to test; the loop condition also stops the loop once the demand is satisfied.
+                    while (requested > 0 && !leftInBuf.isEmpty()) {
+                        checkState();
+
+                        left = leftInBuf.remove();
+
+                        Collection<RowT> rightRows = lookup(left, handler, hashStore, leftJoinPositions, touchResults);
+
+                        if (rightRows.isEmpty()) {
+                            requested--;
+
+                            downstream().push(left);
+                        }
+
+                        // Always reset: the row is fully processed at this point. Leaving it set after the
+                        // demand ran out made the next call remove from a possibly empty buffer.
+                        left = null;
+                    }
+                } finally {
+                    inLoop = false;
+                }
+            }
+
+            getMoreOrEnd();
+        }
+    }
+
+    /**
+     * Finds the right rows stored under the join key of the given left row.
+     *
+     * <p>The store is a chain of nested maps, one level per key column; the last level maps to the bucket of rows.
+     *
+     * @param row Left row to probe with.
+     * @param handler Row handler used to read the key columns.
+     * @param hashStore Root level of the nested key maps.
+     * @param leftJoinPositions Positions of the key columns in the left row.
+     * @param processTouched Whether the found bucket must be flagged as touched; the bucket is then expected
+     *         to be a {@code TouchedList}.
+     * @return The matching bucket, or an empty list if a key column is {@code null}, a level has no entry for the
+     *         key, or there are no key positions.
+     */
+    @SuppressWarnings("unchecked") // The store only ever holds nested maps and row buckets, see pushRight().
+    private static <RowT> Collection<RowT> lookup(
+            RowT row,
+            RowHandler<RowT> handler,
+            Map<Object, Object> hashStore,
+            Collection<Integer> leftJoinPositions,
+            boolean processTouched
+    ) {
+        Map<Object, Object> level = hashStore;
+        int remaining = leftJoinPositions.size();
+
+        for (Integer position : leftJoinPositions) {
+            Object key = handler.get(position, row);
+
+            // A null key column short-circuits the probe with no match.
+            if (key == null) {
+                return Collections.emptyList();
+            }
+
+            Object found = level.get(key);
+
+            if (found == null) {
+                return Collections.emptyList();
+            }
+
+            if (--remaining == 0) {
+                // Last key column: the value fetched above is the bucket itself, no second map access needed.
+                Collection<RowT> bucket = (Collection<RowT>) found;
+
+                if (processTouched) {
+                    ((TouchedList<RowT>) bucket).touched = true;
+                }
+
+                return bucket;
+            }
+
+            level = (Map<Object, Object>) found;
+        }
+
+        return Collections.emptyList();
+    }
+
+    /**
+     * Collects the rows of all buckets that have not been touched yet, descending recursively through the
+     * nested key maps.
+     *
+     * <p>Each collected bucket is flagged as touched, so a repeated call does not return the same rows again.
+     * Without that, a caller whose demand ran out exactly on the last returned row would collect and emit
+     * all of them a second time on its next invocation.
+     *
+     * @param entries Level of the store to scan.
+     * @param out Accumulator to append to, or {@code null} to allocate a new list.
+     * @return The accumulator with the rows of all previously untouched buckets appended.
+     */
+    @SuppressWarnings("unchecked") // Values are either nested maps or TouchedList buckets, see pushRight().
+    private static <RowT> List<RowT> getUntouched(Map<Object, Object> entries, @Nullable List<RowT> out) {
+        if (out == null) {
+            out = new ArrayList<>();
+        }
+
+        for (Object value : entries.values()) {
+            if (value instanceof Collection) {
+                TouchedList<RowT> coll = (TouchedList<RowT>) value;
+
+                if (!coll.touched) {
+                    out.addAll(coll);
+
+                    // The rows are now handed out for emission; never hand them out twice.
+                    coll.touched = true;
+                }
+            } else {
+                getUntouched((Map<Object, Object>) value, out);
+            }
+        }
+
+        return out;
+    }
+
+    /**
+     * Clears the touched flag of every bucket, descending recursively through the nested key maps.
+     *
+     * @param entries Level of the store to scan.
+     */
+    private static <RowT> void resetTouched(Map<Object, Object> entries) {
+        for (Object value : entries.values()) {
+            if (!(value instanceof Collection)) {
+                // Intermediate level: keep descending.
+                resetTouched((Map<Object, Object>) value);
+
+                continue;
+            }
+
+            TouchedList<RowT> bucket = (TouchedList<RowT>) value;
+            bucket.touched = false;
+        }
+    }
+
+    /**
+     * Stores a right row in the hash store: one nested map per key column, with the last level holding the
+     * bucket of rows. Buckets are {@code TouchedList}s when {@link #touchResults} is set and plain
+     * {@link ArrayList}s otherwise. Requests the next batch from the right source once the current one
+     * has been fully received.
+     */
+    @Override
+    protected void pushRight(RowT row) throws Exception {
+        assert downstream() != null;
+        assert waitingRight > 0;
+
+        checkState();
+
+        waitingRight--;
+
+        int keysLeft = rightJoinPositions.size();
+        Map<Object, Object> level = hashStore;
+
+        for (Integer position : rightJoinPositions) {
+            keysLeft--;
+            Object key = handler.get(position, row);
+
+            if (keysLeft > 0) {
+                // Intermediate key column: descend, creating the level on first use.
+                level = (Object2ObjectOpenHashMap<Object, Object>) level.computeIfAbsent(key, k -> new Object2ObjectOpenHashMap<>());
+
+                continue;
+            }
+
+            Collection<RowT> bucket;
+
+            if (touchResults) {
+                bucket = (Collection<RowT>) level.computeIfAbsent(key, k -> new TouchedList<>());
+            } else {
+                bucket = (Collection<RowT>) level.computeIfAbsent(key, k -> new ArrayList<>());
+            }
+
+            bucket.add(row);
+        }
+
+        if (waitingRight == 0) {
+            rightSource().request(waitingRight = inBufSize);
+        }
+    }

Review Comment:
   OK, I rewrote this implementation and found a significant performance boost. 
Let's stay with the new implementation, although it will be difficult to implement 
"is [not] distinct" with it; it seems we will need a different execution node for 
such a case.
   
   new performance run:
   ```
   Benchmark                      (clusterSize)  (fsync)  Mode  Cnt       Score 
     Error  Units
   SqlJoinBenchmark.leftHashJoin              1    false  avgt   20  265234.571 
± 3363.419  us/op
   ```
   The previous result is the one currently shown in the PR header; I may need to 
refresh it after the review is completed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to