This is an automated email from the ASF dual-hosted git repository.

aloalt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-wayang-website.git


The following commit(s) were added to refs/heads/main by this push:
     new 27193011 Add machine learning opt example to guides (#30)
27193011 is described below

commit 2719301110ce21a551164add646c3575414ed0ac
Author: Juri Petersen <[email protected]>
AuthorDate: Tue Jan 30 10:23:16 2024 +0100

    Add machine learning opt example to guides (#30)
---
 docs/guide/usage-examples.md | 197 ++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 194 insertions(+), 3 deletions(-)

diff --git a/docs/guide/usage-examples.md b/docs/guide/usage-examples.md
index 9725185a..ed307b30 100644
--- a/docs/guide/usage-examples.md
+++ b/docs/guide/usage-examples.md
@@ -25,9 +25,200 @@ id: examples
 
 This section provides a set of examples to illustrate how to use Apache Wayang 
for different tasks.
 
-## Example 1: [Title]
-- Description of the task
-- Step-by-step guide
+## Example 1: Machine Learning for query optimization in Apache Wayang
+Apache Wayang can be customized with concrete
+implementations of the `EstimatableCost` interface in order to optimize
+for a desired metric. The implementation can be enabled by providing it
+to a `Configuration`.
+
+```java
+public class CustomEstimatableCost implements EstimatableCost {
+    /* Provide concrete implementations to match desired cost function(s)
+     * by implementing the interface in this class.
+     */
+}
+public class WordCount {
+    public static void main(String[] args) {
+        /* Create a Wayang context and specify the platforms Wayang will 
consider */
+        Configuration config = new Configuration();
+        /* Provision of an EstimatableCost that implements the interface. */
+        config.setCostModel(new CustomEstimatableCost());
+        WayangContext wayangContext = new WayangContext(config)
+                .withPlugin(Java.basicPlugin())
+                .withPlugin(Spark.basicPlugin());
+        /*... omitted */
+    }
+}
+```
+
+In combination with an encoding scheme and a third-party package to load
+ML models, the following example shows how to predict runtimes of query
+execution plans in Apache Wayang (incubating):
+
+```java
+public class MLCost implements EstimatableCost {
+    public EstimatableCostFactory getFactory() {
+        return new Factory();
+    }
+
+    public static class Factory implements EstimatableCostFactory {
+        @Override public EstimatableCost makeCost() {
+            return new MLCost();
+        }
+    }
+
+    @Override public ProbabilisticDoubleInterval 
getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
+        try {
+            Configuration config = plan
+                .getOptimizationContext()
+                .getConfiguration();
+            OrtMLModel model = OrtMLModel.getInstance(config);
+
+            return ProbabilisticDoubleInterval.ofExactly(
+                model.runModel(OneHotEncoder.encode(plan))
+            );
+        } catch(Exception e) {
+            return ProbabilisticDoubleInterval.zero;
+        }
+    }
+
+    @Override public ProbabilisticDoubleInterval 
getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
+        try {
+            Configuration config = plan
+                .getOptimizationContext()
+                .getConfiguration();
+            OrtMLModel model = OrtMLModel.getInstance(config);
+
+            return ProbabilisticDoubleInterval.ofExactly(
+                model.runModel(OneHotEncoder.encode(plan))
+            );
+        } catch(Exception e) {
+            return ProbabilisticDoubleInterval.zero;
+        }
+    }
+
+    /** Returns a squashed cost estimate. */
+    @Override public double getSquashedEstimate(PlanImplementation plan, 
boolean isOverheadIncluded) {
+        try {
+            Configuration config = plan
+                .getOptimizationContext()
+                .getConfiguration();
+            OrtMLModel model = OrtMLModel.getInstance(config);
+
+            return model.runModel(OneHotEncoder.encode(plan));
+        } catch(Exception e) {
+            return 0;
+        }
+    }
+
+    @Override public double getSquashedParallelEstimate(PlanImplementation 
plan, boolean isOverheadIncluded) {
+        try {
+            Configuration config = plan
+                .getOptimizationContext()
+                .getConfiguration();
+            OrtMLModel model = OrtMLModel.getInstance(config);
+
+            return model.runModel(OneHotEncoder.encode(plan));
+        } catch(Exception e) {
+            return 0;
+        }
+    }
+
+    @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> 
getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator 
operator) {
+        List<ProbabilisticDoubleInterval> intervalList = new 
ArrayList<ProbabilisticDoubleInterval>();
+        List<Double> doubleList = new ArrayList<Double>();
+        intervalList.add(this.getEstimate(plan, true));
+        doubleList.add(this.getSquashedEstimate(plan, true));
+
+        return new Tuple<>(intervalList, doubleList);
+    }
+
+    public PlanImplementation pickBestExecutionPlan(
+            Collection<PlanImplementation> executionPlans,
+            ExecutionPlan existingPlan,
+            Set<Channel> openChannels,
+            Set<ExecutionStage> executedStages) {
+        final PlanImplementation bestPlanImplementation = 
executionPlans.stream()
+                .reduce((p1, p2) -> {
+                    final double t1 = p1.getSquashedCostEstimate();
+                    final double t2 = p2.getSquashedCostEstimate();
+                    return t1 < t2 ? p1 : p2;
+                })
+                .orElseThrow(() -> new WayangException("Could not find an 
execution plan."));
+        return bestPlanImplementation;
+    }
+}
+```
+
+Third-party packages such as `OnnxRuntime` can be used to load
+pre-trained `.onnx` files that contain desired ML models.
+
+```java
+public class OrtMLModel {
+
+    private static OrtMLModel INSTANCE;
+
+    private OrtSession session;
+    private OrtEnvironment env;
+
+    private final Map<String, OnnxTensor> inputMap = new HashMap<>();
+    private final Set<String> requestedOutputs = new HashSet<>();
+
+    public static OrtMLModel getInstance(Configuration configuration) throws 
OrtException {
+        if (INSTANCE == null) {
+            INSTANCE = new OrtMLModel(configuration);
+        }
+
+        return INSTANCE;
+    }
+
+    private OrtMLModel(Configuration configuration) throws OrtException {
+        
this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
+    }
+
+    public void loadModel(String filePath) throws OrtException {
+        if (this.env == null) {
+            this.env = OrtEnvironment.getEnvironment();
+        }
+
+        if (this.session == null) {
+            this.session = env.createSession(filePath, new 
OrtSession.SessionOptions());
+        }
+    }
+
+    public void closeSession() throws OrtException {
+        this.session.close();
+        this.env.close();
+    }
+
+    /**
+     * @param encodedVector
+     * @return NaN on error, or the predicted cost otherwise.
+     * @throws OrtException
+     */
+    public double runModel(Vector<Long> encodedVector) throws OrtException {
+        double costPrediction;
+
+        OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector);
+        this.inputMap.put("input", tensor);
+        this.requestedOutputs.add("output");
+
+        BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
+            try {
+                return ((double[]) r.get(s).get().getValue())[0];
+            } catch (OrtException e) {
+                return Double.NaN;
+            }
+        };
+
+        try (Result r = session.run(inputMap, requestedOutputs)) {
+            costPrediction = unwrapFunc.apply(r, "output");
+        }
+
+        return costPrediction;
+    }
+}
+```
 
 ## Example 2: [Title]
 - Description of the task

Reply via email to