This is an automated email from the ASF dual-hosted git repository.
aloalt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-wayang-website.git
The following commit(s) were added to refs/heads/main by this push:
new 27193011 Add machine learning opt example to guides (#30)
27193011 is described below
commit 2719301110ce21a551164add646c3575414ed0ac
Author: Juri Petersen <[email protected]>
AuthorDate: Tue Jan 30 10:23:16 2024 +0100
Add machine learning opt example to guides (#30)
---
docs/guide/usage-examples.md | 197 ++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 194 insertions(+), 3 deletions(-)
diff --git a/docs/guide/usage-examples.md b/docs/guide/usage-examples.md
index 9725185a..ed307b30 100644
--- a/docs/guide/usage-examples.md
+++ b/docs/guide/usage-examples.md
@@ -25,9 +25,200 @@ id: examples
This section provides a set of examples to illustrate how to use Apache Wayang
for different tasks.
-## Example 1: [Title]
-- Description of the task
-- Step-by-step guide
+## Example 1: Machine Learning for query optimization in Apache Wayang
+Apache Wayang can be customized with concrete
+implementations of the `EstimatableCost` interface in order to optimize
+for a desired cost metric. An implementation is enabled by registering it
+with a `Configuration`.
+
+```java
+public class CustomEstimatableCost implements EstimatableCost {
+ /* Provide concrete implementations to match desired cost function(s)
+ * by implementing the interface in this class.
+ */
+}
+public class WordCount {
+ public static void main(String[] args) {
+ /* Create a Wayang context and specify the platforms Wayang will consider */
+ Configuration config = new Configuration();
+ /* Provide an EstimatableCost implementation to the configuration. */
+ config.setCostModel(new CustomEstimatableCost());
+ WayangContext wayangContext = new WayangContext(config)
+ .withPlugin(Java.basicPlugin())
+ .withPlugin(Spark.basicPlugin());
+ /*... omitted */
+ }
+}
+```
+
+In combination with an encoding scheme and a third-party package for loading
+ML models, the following example shows how to predict the runtimes of query
+execution plans in Apache Wayang (incubating):
+
+```java
+public class MLCost implements EstimatableCost {
+ public EstimatableCostFactory getFactory() {
+ return new Factory();
+ }
+
+ public static class Factory implements EstimatableCostFactory {
+ @Override public EstimatableCost makeCost() {
+ return new MLCost();
+ }
+ }
+
+ @Override public ProbabilisticDoubleInterval getEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
+ try {
+ Configuration config = plan
+ .getOptimizationContext()
+ .getConfiguration();
+ OrtMLModel model = OrtMLModel.getInstance(config);
+
+ return ProbabilisticDoubleInterval.ofExactly(
+ model.runModel(OneHotEncoder.encode(plan))
+ );
+ } catch(Exception e) {
+ return ProbabilisticDoubleInterval.zero;
+ }
+ }
+
+ @Override public ProbabilisticDoubleInterval getParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
+ try {
+ Configuration config = plan
+ .getOptimizationContext()
+ .getConfiguration();
+ OrtMLModel model = OrtMLModel.getInstance(config);
+
+ return ProbabilisticDoubleInterval.ofExactly(
+ model.runModel(OneHotEncoder.encode(plan))
+ );
+ } catch(Exception e) {
+ return ProbabilisticDoubleInterval.zero;
+ }
+ }
+
+ /** Returns a squashed cost estimate. */
+ @Override public double getSquashedEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
+ try {
+ Configuration config = plan
+ .getOptimizationContext()
+ .getConfiguration();
+ OrtMLModel model = OrtMLModel.getInstance(config);
+
+ return model.runModel(OneHotEncoder.encode(plan));
+ } catch(Exception e) {
+ return 0;
+ }
+ }
+
+ @Override public double getSquashedParallelEstimate(PlanImplementation plan, boolean isOverheadIncluded) {
+ try {
+ Configuration config = plan
+ .getOptimizationContext()
+ .getConfiguration();
+ OrtMLModel model = OrtMLModel.getInstance(config);
+
+ return model.runModel(OneHotEncoder.encode(plan));
+ } catch(Exception e) {
+ return 0;
+ }
+ }
+
+ @Override public Tuple<List<ProbabilisticDoubleInterval>, List<Double>> getParallelOperatorJunctionAllCostEstimate(PlanImplementation plan, Operator operator) {
+ List<ProbabilisticDoubleInterval> intervalList = new ArrayList<ProbabilisticDoubleInterval>();
+ List<Double> doubleList = new ArrayList<Double>();
+ intervalList.add(this.getEstimate(plan, true));
+ doubleList.add(this.getSquashedEstimate(plan, true));
+
+ return new Tuple<>(intervalList, doubleList);
+ }
+
+ public PlanImplementation pickBestExecutionPlan(
+ Collection<PlanImplementation> executionPlans,
+ ExecutionPlan existingPlan,
+ Set<Channel> openChannels,
+ Set<ExecutionStage> executedStages) {
+ final PlanImplementation bestPlanImplementation = executionPlans.stream()
+ .reduce((p1, p2) -> {
+ final double t1 = p1.getSquashedCostEstimate();
+ final double t2 = p2.getSquashedCostEstimate();
+ return t1 < t2 ? p1 : p2;
+ })
+ .orElseThrow(() -> new WayangException("Could not find an execution plan."));
+ return bestPlanImplementation;
+ }
+}
+```
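+
+The `OneHotEncoder` referenced above is the encoding scheme that turns a
+`PlanImplementation` into the feature vector consumed by the model; its
+implementation is omitted here. As a purely hypothetical sketch of the idea
+(not the actual Wayang API), an encoder could count how often each known
+operator type occurs in a plan and write the counts to fixed positions of a
+vector with the shape expected by `runModel`:
+
+```java
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Vector;
+
+/* Hypothetical sketch only; the real OneHotEncoder works on
+ * PlanImplementation objects rather than plain operator names. */
+public class SimplePlanEncoder {
+
+    /* Illustrative, fixed vocabulary of operator type names;
+     * each name gets one slot in the feature vector. */
+    private static final List<String> OPERATOR_TYPES = Arrays.asList(
+        "TextFileSource", "MapOperator", "FlatMapOperator",
+        "ReduceByOperator", "LocalCallbackSink");
+
+    /* Encodes a plan, given as a list of operator type names,
+     * into a count vector of the shape expected by OrtMLModel.runModel. */
+    public static Vector<Long> encode(List<String> operatorNames) {
+        Vector<Long> encoded =
+            new Vector<>(Collections.nCopies(OPERATOR_TYPES.size(), 0L));
+        for (String name : operatorNames) {
+            int index = OPERATOR_TYPES.indexOf(name);
+            if (index >= 0) {
+                encoded.set(index, encoded.get(index) + 1);
+            }
+        }
+        return encoded;
+    }
+}
+```
+
+Using a fixed vocabulary keeps the input shape constant, which matches the
+single `input` tensor that the `OrtMLModel` below feeds to the ONNX session.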
+
+Third-party packages such as `OnnxRuntime` can be used to load
+pre-trained `.onnx` files that contain the desired ML models.
+
+```java
+public class OrtMLModel {
+
+ private static OrtMLModel INSTANCE;
+
+ private OrtSession session;
+ private OrtEnvironment env;
+
+ private final Map<String, OnnxTensor> inputMap = new HashMap<>();
+ private final Set<String> requestedOutputs = new HashSet<>();
+
+ public static OrtMLModel getInstance(Configuration configuration) throws OrtException {
+ if (INSTANCE == null) {
+ INSTANCE = new OrtMLModel(configuration);
+ }
+
+ return INSTANCE;
+ }
+
+ private OrtMLModel(Configuration configuration) throws OrtException {
+ this.loadModel(configuration.getStringProperty("wayang.ml.model.file"));
+ }
+
+ public void loadModel(String filePath) throws OrtException {
+ if (this.env == null) {
+ this.env = OrtEnvironment.getEnvironment();
+ }
+
+ if (this.session == null) {
+ this.session = env.createSession(filePath, new OrtSession.SessionOptions());
+ }
+ }
+
+ public void closeSession() throws OrtException {
+ this.session.close();
+ this.env.close();
+ }
+
+ /**
+ * @param encodedVector the one-hot encoded execution plan
+ * @return the predicted cost, or NaN on error
+ * @throws OrtException
+ */
+ public double runModel(Vector<Long> encodedVector) throws OrtException {
+ double costPrediction;
+
+ OnnxTensor tensor = OnnxTensor.createTensor(env, encodedVector);
+ this.inputMap.put("input", tensor);
+ this.requestedOutputs.add("output");
+
+ BiFunction<Result, String, Double> unwrapFunc = (r, s) -> {
+ try {
+ return ((double[]) r.get(s).get().getValue())[0];
+ } catch (OrtException e) {
+ return Double.NaN;
+ }
+ };
+
+ try (Result r = session.run(inputMap, requestedOutputs)) {
+ costPrediction = unwrapFunc.apply(r, "output");
+ }
+
+ return costPrediction;
+ }
+}
+```
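+
+Putting the pieces together, the ML-backed cost model is enabled just like
+the `CustomEstimatableCost` in the first snippet of this example. The
+following sketch assumes that the model path can be supplied as a regular
+configuration property under the key read by `OrtMLModel`, and that a
+pre-trained `.onnx` file exists at that path; the statements belong in a
+setup method such as the `main` method above:
+
+```java
+Configuration config = new Configuration();
+/* Tell OrtMLModel where to find the pre-trained ONNX model; the property
+ * key matches the one read in OrtMLModel above. */
+config.setProperty("wayang.ml.model.file", "/path/to/model.onnx");
+/* Let the optimizer pick plans based on the ML cost estimates. */
+config.setCostModel(new MLCost());
+WayangContext wayangContext = new WayangContext(config)
+    .withPlugin(Java.basicPlugin())
+    .withPlugin(Spark.basicPlugin());
+/*... omitted */
+```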
## Example 2: [Title]
- Description of the task