KYLIN-976 Support multiple implementations for one aggr function

Project: http://git-wip-us.apache.org/repos/asf/kylin/repo
Commit: http://git-wip-us.apache.org/repos/asf/kylin/commit/2d60c789
Tree: http://git-wip-us.apache.org/repos/asf/kylin/tree/2d60c789
Diff: http://git-wip-us.apache.org/repos/asf/kylin/diff/2d60c789

Branch: refs/heads/1.x-HBase1.1.3
Commit: 2d60c7896dbda22758d0d13b8d4af928e41b0017
Parents: 1b51eb5
Author: lidongsjtu <don...@ebay.com>
Authored: Sun Dec 27 15:13:52 2015 +0800
Committer: lidongsjtu <don...@ebay.com>
Committed: Sun Dec 27 15:13:52 2015 +0800

----------------------------------------------------------------------
 .../kylin/measure/MeasureTypeFactory.java       | 96 ++++++++++++++++----
 .../kylin/query/relnode/OLAPAggregateRel.java   | 93 +++++++++++++------
 .../apache/kylin/query/routing/QueryRouter.java |  7 --
 3 files changed, 141 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/kylin/blob/2d60c789/metadata/src/main/java/org/apache/kylin/measure/MeasureTypeFactory.java
----------------------------------------------------------------------
diff --git a/metadata/src/main/java/org/apache/kylin/measure/MeasureTypeFactory.java b/metadata/src/main/java/org/apache/kylin/measure/MeasureTypeFactory.java
index 0784f91..2eaf366 100644
--- a/metadata/src/main/java/org/apache/kylin/measure/MeasureTypeFactory.java
+++ b/metadata/src/main/java/org/apache/kylin/measure/MeasureTypeFactory.java
@@ -31,19 +31,8 @@ import com.google.common.collect.Maps;
 
 abstract public class MeasureTypeFactory<T> {
 
-    abstract public MeasureType<T> createMeasureType(String funcName, DataType 
dataType);
-
-    abstract public String getAggrFunctionName();
-
-    abstract public String getAggrDataTypeName();
-
-    abstract public Class<? extends DataTypeSerializer<T>> 
getAggrDataTypeSerializer();
-
-    // 
============================================================================
-
-
-    private static Map<String, MeasureTypeFactory<?>> factories = 
Maps.newHashMap();
-    private static MeasureTypeFactory<?> defaultFactory = new 
BasicMeasureType.Factory();
+    private static Map<String, List<MeasureTypeFactory<?>>> factories = 
Maps.newHashMap();
+    private static List<MeasureTypeFactory<?>> defaultFactory = 
Lists.newArrayListWithCapacity(2);
 
     static {
         init();
@@ -67,16 +56,27 @@ abstract public class MeasureTypeFactory<T> {
 
         // register factories & data type serializers
         for (MeasureTypeFactory<?> factory : factoryInsts) {
-            String funcName = factory.getAggrFunctionName().toUpperCase();
-            String dataTypeName = factory.getAggrDataTypeName().toLowerCase();
+            String funcName = factory.getAggrFunctionName();
+            if (funcName.equals(funcName.toUpperCase()) == false)
+                throw new IllegalArgumentException("Aggregation function name 
'" + funcName + "' must be in upper case");
+            String dataTypeName = factory.getAggrDataTypeName();
+            if (dataTypeName.equals(dataTypeName.toLowerCase()) == false)
+                throw new IllegalArgumentException("Aggregation data type name 
'" + dataTypeName + "' must be in lower case");
             Class<? extends DataTypeSerializer<?>> serializer = 
factory.getAggrDataTypeSerializer();
 
             DataType.register(dataTypeName);
             DataTypeSerializer.register(dataTypeName, serializer);
-            factories.put(funcName, factory);
+            List<MeasureTypeFactory<?>> list = factories.get(funcName);
+            if (list == null)
+                factories.put(funcName, list = 
Lists.newArrayListWithCapacity(2));
+            list.add(factory);
         }
+
+        defaultFactory.add(new BasicMeasureType.Factory());
     }
 
+    // 
============================================================================
+
     public static MeasureType<?> create(String funcName, String dataType) {
         return create(funcName, DataType.getInstance(dataType));
     }
@@ -84,10 +84,70 @@ abstract public class MeasureTypeFactory<T> {
     public static MeasureType<?> create(String funcName, DataType dataType) {
         funcName = funcName.toUpperCase();
 
-        MeasureTypeFactory<?> factory = factories.get(funcName);
+        List<MeasureTypeFactory<?>> factory = factories.get(funcName);
         if (factory == null)
             factory = defaultFactory;
 
-        return factory.createMeasureType(funcName, dataType);
+        // a special case where in early stage of sql parsing, the data type 
is unknown; only needRewrite() is required at that stage
+        if (dataType == null) {
+            return new NeedRewriteOnlyMeasureType(funcName, factory);
+        }
+
+        // the normal case, only one factory for a function
+        if (factory.size() == 1) {
+            return factory.get(0).createMeasureType(funcName, dataType);
+        }
+
+        // sometimes multiple factories are registered for the same function, 
then data types must tell them apart
+        for (MeasureTypeFactory<?> f : factory) {
+            if (f.getAggrDataTypeName().equals(dataType.getName()))
+                return f.createMeasureType(funcName, dataType);
+        }
+        throw new IllegalStateException();
+    };
+
+    abstract public MeasureType<T> createMeasureType(String funcName, DataType 
dataType);
+
+    abstract public String getAggrFunctionName();
+
+    abstract public String getAggrDataTypeName();
+
+    abstract public Class<? extends DataTypeSerializer<T>> 
getAggrDataTypeSerializer();
+
+    @SuppressWarnings("rawtypes")
+    private static class NeedRewriteOnlyMeasureType extends MeasureType {
+
+        private Boolean needRewrite;
+
+        public NeedRewriteOnlyMeasureType(String funcName, 
List<MeasureTypeFactory<?>> factory) {
+            for (MeasureTypeFactory<?> f : factory) {
+                boolean b = f.createMeasureType(funcName, null).needRewrite();
+                if (needRewrite == null)
+                    needRewrite = Boolean.valueOf(b);
+                else if (needRewrite.booleanValue() != b)
+                    throw new IllegalStateException("needRewrite() of factorys 
" + factory + " does not have consensus");
+            }
+        }
+
+        @Override
+        public MeasureIngester newIngester() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public MeasureAggregator newAggregator() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public boolean needRewrite() {
+            return needRewrite;
+        }
+
+        @Override
+        public Class getRewriteCalciteAggrFunctionClass() {
+            throw new UnsupportedOperationException();
+        }
+
     }
 }

http://git-wip-us.apache.org/repos/asf/kylin/blob/2d60c789/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
----------------------------------------------------------------------
diff --git a/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java b/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
index f225fe2..9e75106 100644
--- a/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
+++ b/query/src/main/java/org/apache/kylin/query/relnode/OLAPAggregateRel.java
@@ -25,6 +25,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import com.google.common.collect.Lists;
 import org.apache.calcite.adapter.enumerable.EnumerableAggregate;
 import org.apache.calcite.adapter.enumerable.EnumerableConvention;
 import org.apache.calcite.adapter.enumerable.EnumerableRel;
@@ -56,6 +57,7 @@ import org.apache.calcite.util.ImmutableBitSet;
 import org.apache.calcite.util.Util;
 import org.apache.kylin.metadata.model.ColumnDesc;
 import org.apache.kylin.metadata.model.FunctionDesc;
+import org.apache.kylin.metadata.model.MeasureDesc;
 import org.apache.kylin.metadata.model.ParameterDesc;
 import org.apache.kylin.metadata.model.TableDesc;
 import org.apache.kylin.metadata.model.TblColRef;
@@ -133,11 +135,10 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
         this.columnRowType = buildColumnRowType();
         this.afterAggregate = this.context.afterAggregate;
 
-        // only translate the first aggregation
+        // only translate the innermost aggregation
         if (!this.afterAggregate) {
             translateGroupBy();
-            fillbackOptimizedColumn();
-            translateAggregation();
+            this.context.aggregations.addAll(this.aggregations);
             this.context.afterAggregate = true;
         } else {
             for (AggregateCall aggCall : aggCalls) {
@@ -224,14 +225,71 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
         context.groupByColumns.addAll(this.groups);
     }
 
+    @Override
+    public void implementRewrite(RewriteImplementor implementor) {
+        // only rewrite the innermost aggregation
+        if (!this.afterAggregate) {
+            translateAggregation();
+            buildRewriteFieldsAndMetricsColumns();
+        }
+
+        implementor.visitChild(this, getInput());
+
+        // only rewrite the innermost aggregation
+        if (!this.afterAggregate && 
RewriteImplementor.needRewrite(this.context)) {
+            // rewrite the aggCalls
+            this.rewriteAggCalls = new 
ArrayList<AggregateCall>(aggCalls.size());
+            for (int i = 0; i < this.aggCalls.size(); i++) {
+                AggregateCall aggCall = this.aggCalls.get(i);
+                FunctionDesc cubeFunc = this.context.aggregations.get(i);
+                if (cubeFunc.needRewrite()) {
+                    aggCall = rewriteAggregateCall(aggCall, cubeFunc);
+                }
+                this.rewriteAggCalls.add(aggCall);
+            }
+        }
+
+        // rebuild rowType & columnRowType
+        this.rowType = this.deriveRowType();
+        this.columnRowType = this.buildColumnRowType();
+    }
+
     private void translateAggregation() {
+        // now the realization is known, replace aggregations with what's 
defined on MeasureDesc
+        List<MeasureDesc> measures = this.context.realization.getMeasures();
+        List<FunctionDesc> newAggrs = Lists.newArrayList();
+        for (FunctionDesc aggFunc : this.aggregations) {
+            newAggrs.add(findInMeasures(aggFunc, measures));
+        }
+        this.aggregations.clear();
+        this.aggregations.addAll(newAggrs);
+        this.context.aggregations.clear();
+        this.context.aggregations.addAll(newAggrs);
+    }
+
+    private FunctionDesc findInMeasures(FunctionDesc aggFunc, 
List<MeasureDesc> measures) {
+        for (MeasureDesc m : measures) {
+            if (aggFunc.equals(m.getFunction()))
+                return m.getFunction();
+        }
+        return aggFunc;
+    }
+
+    private void buildRewriteFieldsAndMetricsColumns() {
+        fillbackOptimizedColumn();
+
         ColumnRowType inputColumnRowType = ((OLAPRel) 
getInput()).getColumnRowType();
         for (int i = 0; i < this.aggregations.size(); i++) {
             FunctionDesc aggFunc = this.aggregations.get(i);
-            context.aggregations.add(aggFunc);
+
+            if (aggFunc.isDimensionAsMetric()) {
+                
this.context.groupByColumns.addAll(aggFunc.getParameter().getColRefs());
+                continue; // skip rewrite, let calcite handle
+            }
+
             if (aggFunc.needRewrite()) {
                 String rewriteFieldName = aggFunc.getRewriteFieldName();
-                context.rewriteFields.put(rewriteFieldName, null);
+                this.context.rewriteFields.put(rewriteFieldName, null);
 
                 TblColRef column = buildRewriteColumn(aggFunc);
                 this.context.metricsColumns.add(column);
@@ -263,31 +321,6 @@ public class OLAPAggregateRel extends Aggregate implements OLAPRel {
         }
     }
 
-    @Override
-    public void implementRewrite(RewriteImplementor implementor) {
-        implementor.visitChild(this, getInput());
-
-        // only rewrite the first aggregation
-        if (!this.afterAggregate && 
RewriteImplementor.needRewrite(this.context)) {
-            // rewrite the aggCalls
-            this.rewriteAggCalls = new 
ArrayList<AggregateCall>(aggCalls.size());
-            for (int i = 0; i < this.aggCalls.size(); i++) {
-                AggregateCall aggCall = this.aggCalls.get(i);
-                FunctionDesc cubeFunc = this.context.aggregations.get(i);
-                if (cubeFunc.needRewrite()) {
-                    aggCall = rewriteAggregateCall(aggCall, cubeFunc);
-                }
-                this.rewriteAggCalls.add(aggCall);
-            }
-        }
-
-        // rebuild rowType & columnRowType
-        //ClassUtil.updateFinalField(Aggregate.class, "aggCalls", this, 
rewriteAggCalls);
-        this.rowType = this.deriveRowType(); // this does not work coz 
super.aggCalls is final
-        this.columnRowType = this.buildColumnRowType();
-
-    }
-
     private AggregateCall rewriteAggregateCall(AggregateCall aggCall, 
FunctionDesc func) {
 
         // rebuild parameters

http://git-wip-us.apache.org/repos/asf/kylin/blob/2d60c789/query/src/main/java/org/apache/kylin/query/routing/QueryRouter.java
----------------------------------------------------------------------
diff --git a/query/src/main/java/org/apache/kylin/query/routing/QueryRouter.java b/query/src/main/java/org/apache/kylin/query/routing/QueryRouter.java
index 7493e08..59469f6 100644
--- a/query/src/main/java/org/apache/kylin/query/routing/QueryRouter.java
+++ b/query/src/main/java/org/apache/kylin/query/routing/QueryRouter.java
@@ -84,13 +84,6 @@ public class QueryRouter {
             if (inf instanceof DimensionAsMeasure) {
                 FunctionDesc functionDesc = ((DimensionAsMeasure) 
inf).getMeasureFunction();
                 functionDesc.setDimensionAsMetric(true);
-                
olapContext.rewriteFields.remove(functionDesc.getRewriteFieldName());
-                for (TblColRef col : functionDesc.getParameter().getColRefs()) 
{
-                    if (col != null) {
-                        olapContext.metricsColumns.remove(col);
-                        olapContext.groupByColumns.add(col);
-                    }
-                }
                 logger.info("Adjust DimensionAsMeasure for " + functionDesc);
             }
         }

Reply via email to