This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new a9a019b2 [SEDONA-318] Switch raster serde to a deep copy 
implementation (#892)
a9a019b2 is described below

commit a9a019b23086ff251b1f5f8ff21e459ade4cbcfe
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Fri Jul 7 12:18:05 2023 +0800

    [SEDONA-318] Switch raster serde to a deep copy implementation (#892)
---
 .../common/raster/DeepCopiedRenderedImage.java     | 369 +++++++++++++++++++++
 .../org/apache/sedona/common/raster/Serde.java     |  59 ++--
 .../org/apache/sedona/common/raster/SerdeTest.java | 103 +++++-
 3 files changed, 496 insertions(+), 35 deletions(-)

diff --git 
a/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java
 
b/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java
new file mode 100644
index 00000000..baf62a86
--- /dev/null
+++ 
b/common/src/main/java/org/apache/sedona/common/raster/DeepCopiedRenderedImage.java
@@ -0,0 +1,369 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sedona.common.raster;
+
+import com.sun.media.jai.util.ImageUtil;
+
+import javax.media.jai.JAI;
+import javax.media.jai.PlanarImage;
+import javax.media.jai.RasterAccessor;
+import javax.media.jai.RasterFormatTag;
+import javax.media.jai.RemoteImage;
+import javax.media.jai.TileCache;
+import javax.media.jai.remote.SerializableState;
+import javax.media.jai.remote.SerializerFactory;
+import java.awt.Image;
+import java.awt.Point;
+import java.awt.Rectangle;
+import java.awt.image.ColorModel;
+import java.awt.image.DataBuffer;
+import java.awt.image.Raster;
+import java.awt.image.RenderedImage;
+import java.awt.image.SampleModel;
+import java.awt.image.WritableRaster;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.util.Enumeration;
+import java.util.Hashtable;
+import java.util.Vector;
+
+/**
+ * This class is mostly copied from {@link 
javax.media.jai.remote.SerializableRenderedImage}. We've removed the
+ * shallow copy support and fixed a bug of SerializableRenderedImage: When a 
deep-copied serializable rendered image
+ * object is being disposed, it tries to connect to the remote server. 
However, there is no remote server in deep-copy
+ * mode, so the dispose() method throws a java.net.SocketException.
+ */
+public final class DeepCopiedRenderedImage implements RenderedImage, 
Serializable {
+    private transient RenderedImage source;
+    private int minX;
+    private int minY;
+    private int width;
+    private int height;
+    private int minTileX;
+    private int minTileY;
+    private int numXTiles;
+    private int numYTiles;
+    private int tileWidth;
+    private int tileHeight;
+    private int tileGridXOffset;
+    private int tileGridYOffset;
+    private transient SampleModel sampleModel;
+    private transient ColorModel colorModel;
+    private transient Vector<RenderedImage> sources;
+    private transient Hashtable<String, Object> properties;
+    private Rectangle imageBounds;
+    private transient Raster imageRaster;
+
+    DeepCopiedRenderedImage() {
+        this.sampleModel = null;
+        this.colorModel = null;
+        this.sources = null;
+        this.properties = null;
+    }
+
+    public DeepCopiedRenderedImage(RenderedImage source) {
+        this(source, true);
+    }
+
+    private DeepCopiedRenderedImage(RenderedImage source, boolean 
checkDataBuffer) {
+        this.sampleModel = null;
+        this.colorModel = null;
+        this.sources = null;
+        this.properties = null;
+        if (source == null) {
+            throw new IllegalArgumentException("source cannot be null");
+        } else {
+            SampleModel sm = source.getSampleModel();
+            if (sm != null && SerializerFactory.getSerializer(sm.getClass()) 
== null) {
+                throw new IllegalArgumentException("sample model object is not 
serializable");
+            } else {
+                ColorModel cm = source.getColorModel();
+                if (cm != null && 
SerializerFactory.getSerializer(cm.getClass()) == null) {
+                    throw new IllegalArgumentException("color model object is 
not serializable");
+                } else {
+                    if (checkDataBuffer) {
+                        Raster ras = source.getTile(source.getMinTileX(), 
source.getMinTileY());
+                        if (ras != null) {
+                            DataBuffer db = ras.getDataBuffer();
+                            if (db != null && 
SerializerFactory.getSerializer(db.getClass()) == null) {
+                                throw new IllegalArgumentException("data 
buffer object is not serializable");
+                            }
+                        }
+                    }
+
+                    this.source = source;
+                    if (source instanceof RemoteImage) {
+                        throw new IllegalArgumentException("RemoteImage is not 
supported");
+                    }
+                    this.minX = source.getMinX();
+                    this.minY = source.getMinY();
+                    this.width = source.getWidth();
+                    this.height = source.getHeight();
+                    this.minTileX = source.getMinTileX();
+                    this.minTileY = source.getMinTileY();
+                    this.numXTiles = source.getNumXTiles();
+                    this.numYTiles = source.getNumYTiles();
+                    this.tileWidth = source.getTileWidth();
+                    this.tileHeight = source.getTileHeight();
+                    this.tileGridXOffset = source.getTileGridXOffset();
+                    this.tileGridYOffset = source.getTileGridYOffset();
+                    this.sampleModel = source.getSampleModel();
+                    this.colorModel = source.getColorModel();
+                    this.sources = new Vector<>();
+                    this.sources.add(source);
+                    this.properties = new Hashtable<>();
+                    String[] propertyNames = source.getPropertyNames();
+                    if (propertyNames != null) {
+                        for (String propertyName : propertyNames) {
+                            this.properties.put(propertyName, 
source.getProperty(propertyName));
+                        }
+                    }
+
+                    this.imageBounds = new Rectangle(this.minX, this.minY, 
this.width, this.height);
+                }
+            }
+        }
+    }
+
+    @Override
+    public ColorModel getColorModel() {
+        return this.colorModel;
+    }
+
+    @Override
+    public Raster getData() {
+        if (source == null) {
+            return this.getData(this.imageBounds);
+        } else {
+            return this.source.getData();
+        }
+    }
+
+    @Override
+    public Raster getData(Rectangle rect) {
+        if (source == null) {
+            return this.imageRaster.createChild(rect.x, rect.y, rect.width, 
rect.height, rect.x, rect.y, (int[])null);
+        } else {
+            return this.source.getData(rect);
+        }
+    }
+
+    @Override
+    public WritableRaster copyData(WritableRaster dest) {
+        if (source == null) {
+            Rectangle region;
+            if (dest == null) {
+                region = this.imageBounds;
+                SampleModel destSM = 
this.getSampleModel().createCompatibleSampleModel(region.width, region.height);
+                dest = Raster.createWritableRaster(destSM, new Point(region.x, 
region.y));
+            } else {
+                region = dest.getBounds().intersection(this.imageBounds);
+            }
+
+            if (!region.isEmpty()) {
+                int startTileX = PlanarImage.XToTileX(region.x, 
this.tileGridXOffset, this.tileWidth);
+                int startTileY = PlanarImage.YToTileY(region.y, 
this.tileGridYOffset, this.tileHeight);
+                int endTileX = PlanarImage.XToTileX(region.x + region.width - 
1, this.tileGridXOffset, this.tileWidth);
+                int endTileY = PlanarImage.YToTileY(region.y + region.height - 
1, this.tileGridYOffset, this.tileHeight);
+                SampleModel[] sampleModels = new 
SampleModel[]{this.getSampleModel()};
+                int tagID = RasterAccessor.findCompatibleTag(sampleModels, 
dest.getSampleModel());
+                RasterFormatTag srcTag = new 
RasterFormatTag(this.getSampleModel(), tagID);
+                RasterFormatTag dstTag = new 
RasterFormatTag(dest.getSampleModel(), tagID);
+
+                for(int ty = startTileY; ty <= endTileY; ++ty) {
+                    for(int tx = startTileX; tx <= endTileX; ++tx) {
+                        Raster tile = this.getTile(tx, ty);
+                        Rectangle subRegion = 
region.intersection(tile.getBounds());
+                        RasterAccessor s = new RasterAccessor(tile, subRegion, 
srcTag, this.getColorModel());
+                        RasterAccessor d = new RasterAccessor(dest, subRegion, 
dstTag, null);
+                        ImageUtil.copyRaster(s, d);
+                    }
+                }
+            }
+
+            return dest;
+        } else {
+            return this.source.copyData(dest);
+        }
+    }
+
+    @Override
+    public int getHeight() {
+        return this.height;
+    }
+
+    @Override
+    public int getMinTileX() {
+        return this.minTileX;
+    }
+
+    @Override
+    public int getMinTileY() {
+        return this.minTileY;
+    }
+
+    @Override
+    public int getMinX() {
+        return this.minX;
+    }
+
+    @Override
+    public int getMinY() {
+        return this.minY;
+    }
+
+    @Override
+    public int getNumXTiles() {
+        return this.numXTiles;
+    }
+
+    @Override
+    public int getNumYTiles() {
+        return this.numYTiles;
+    }
+
+    @Override
+    public Object getProperty(String name) {
+        Object property = this.properties.get(name);
+        return property == null ? Image.UndefinedProperty : property;
+    }
+
+    @Override
+    public String[] getPropertyNames() {
+        String[] names = null;
+        if (!this.properties.isEmpty()) {
+            names = new String[this.properties.size()];
+            Enumeration<String> keys = this.properties.keys();
+            int index = 0;
+            while (keys.hasMoreElements()) {
+                names[index++] = keys.nextElement();
+            }
+        }
+
+        return names;
+    }
+
+    @Override
+    public SampleModel getSampleModel() {
+        return this.sampleModel;
+    }
+
+    @Override
+    public Vector<RenderedImage> getSources() {
+        return this.sources;
+    }
+
+    @Override
+    public Raster getTile(int tileX, int tileY) {
+        if (source == null) {
+            TileCache cache = JAI.getDefaultInstance().getTileCache();
+            if (cache != null) {
+                Raster tile = cache.getTile(this, tileX, tileY);
+                if (tile != null) {
+                    return tile;
+                }
+            }
+
+            Rectangle imageBounds = new Rectangle(this.getMinX(), 
this.getMinY(), this.getWidth(), this.getHeight());
+            Rectangle destRect = imageBounds.intersection(new 
Rectangle(this.tileXToX(tileX), this.tileYToY(tileY), this.getTileWidth(), 
this.getTileHeight()));
+            Raster tile = this.getData(destRect);
+            if (cache != null) {
+                cache.add(this, tileX, tileY, tile);
+            }
+
+            return tile;
+        } else {
+            return this.source.getTile(tileX, tileY);
+        }
+    }
+
+    private int tileXToX(int tx) {
+        return PlanarImage.tileXToX(tx, this.getTileGridXOffset(), 
this.getTileWidth());
+    }
+
+    private int tileYToY(int ty) {
+        return PlanarImage.tileYToY(ty, this.getTileGridYOffset(), 
this.getTileHeight());
+    }
+
+    @Override
+    public int getTileGridXOffset() {
+        return this.tileGridXOffset;
+    }
+
+    @Override
+    public int getTileGridYOffset() {
+        return this.tileGridYOffset;
+    }
+
+    @Override
+    public int getTileHeight() {
+        return this.tileHeight;
+    }
+
+    @Override
+    public int getTileWidth() {
+        return this.tileWidth;
+    }
+
+    @Override
+    public int getWidth() {
+        return this.width;
+    }
+
+    @SuppressWarnings("unchecked")
+    private void writeObject(ObjectOutputStream out) throws IOException {
+        out.defaultWriteObject();
+
+        // Prepare serialize properties. non-serializable properties won't be 
serialized.
+        Hashtable<String, Object> propertyTable = this.properties;
+        boolean propertiesCloned = false;
+        Enumeration<String> keys = propertyTable.keys();
+        while (keys.hasMoreElements()) {
+            Object key = keys.nextElement();
+            if (!(this.properties.get(key) instanceof Serializable)) {
+                if (!propertiesCloned) {
+                    propertyTable = (Hashtable<String, Object>) 
this.properties.clone();
+                    propertiesCloned = true;
+                }
+                propertyTable.remove(key);
+            }
+        }
+
+        out.writeObject(SerializerFactory.getState(this.sampleModel, null));
+        out.writeObject(SerializerFactory.getState(this.colorModel, null));
+        out.writeObject(propertyTable);
+        if (this.source != null) {
+            out.writeObject(SerializerFactory.getState(this.source.getData(), 
null));
+        } else {
+            out.writeObject(SerializerFactory.getState(imageRaster, null));
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private void readObject(ObjectInputStream in) throws IOException, 
ClassNotFoundException {
+        this.source = null;
+        this.colorModel = null;
+        in.defaultReadObject();
+
+        SerializableState smState = (SerializableState)in.readObject();
+        this.sampleModel = (SampleModel)smState.getObject();
+        SerializableState cmState = (SerializableState)in.readObject();
+        this.colorModel = (ColorModel)cmState.getObject();
+        this.properties = (Hashtable<String, Object>)in.readObject();
+        SerializableState rasState = (SerializableState)in.readObject();
+        this.imageRaster = (Raster)rasState.getObject();
+    }
+}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/Serde.java 
b/common/src/main/java/org/apache/sedona/common/raster/Serde.java
index 2a9e4d7e..9ca3959d 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/Serde.java
@@ -16,38 +16,59 @@ package org.apache.sedona.common.raster;
 import org.geotools.coverage.grid.GridCoverage2D;
 import org.geotools.coverage.grid.GridCoverageFactory;
 
+import javax.media.jai.PlanarImage;
+import javax.media.jai.RenderedImageAdapter;
 import javax.media.jai.remote.SerializableRenderedImage;
-
 import java.awt.image.RenderedImage;
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.lang.reflect.Field;
 
 public class Serde {
 
+    static final Field field;
+    static {
+        try {
+            field = GridCoverage2D.class.getDeclaredField("serializedImage");
+            field.setAccessible(true);
+        } catch (NoSuchFieldException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     public static byte[] serialize(GridCoverage2D raster) throws IOException {
         // GridCoverage2D created by GridCoverage2DReaders contain references 
that are not serializable.
-        // Wrap the RenderedImage in SerializableRenderedImage to make it 
serializable.
-        if (!(raster.getRenderedImage() instanceof SerializableRenderedImage)) 
{
-            RenderedImage renderedImage = new SerializableRenderedImage(
-                    raster.getRenderedImage(),
-                    false,
-                    null,
-                    "gzip",
-                    null,
-                    null
-            );
-            raster = new GridCoverageFactory().create(
-                    raster.getName(),
-                    renderedImage,
-                    raster.getGridGeometry(),
-                    raster.getSampleDimensions(),
-                    null,
-                    raster.getProperties()
-            );
+        // Wrap the RenderedImage in DeepCopiedRenderedImage to make it 
serializable.
+        RenderedImage deepCopiedRenderedImage = null;
+        RenderedImage renderedImage = raster.getRenderedImage();
+        while (renderedImage instanceof RenderedImageAdapter) {
+            renderedImage = ((RenderedImageAdapter) 
renderedImage).getWrappedImage();
         }
+        if (renderedImage instanceof DeepCopiedRenderedImage) {
+            deepCopiedRenderedImage = renderedImage;
+        } else {
+            deepCopiedRenderedImage = new 
DeepCopiedRenderedImage(raster.getRenderedImage());
+        }
+        raster = new GridCoverageFactory().create(
+                raster.getName(),
+                deepCopiedRenderedImage,
+                raster.getGridGeometry(),
+                raster.getSampleDimensions(),
+                null,
+                raster.getProperties());
+
+        // Set the serializedImage field so that GridCoverage2D serializes the DeepCopiedRenderedImage object
+        // created above, rather than creating a SerializableRenderedImage and serializing that. The whole point
+        // of DeepCopiedRenderedImage is to get rid of SerializableRenderedImage, which is problematic.
+        try {
+            field.set(raster, deepCopiedRenderedImage);
+        } catch (IllegalAccessException e) {
+            throw new RuntimeException(e);
+        }
+
         try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
             try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
                 oos.writeObject(raster);
diff --git 
a/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java 
b/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java
index 1a13534b..55dc85f3 100644
--- a/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java
+++ b/common/src/test/java/org/apache/sedona/common/raster/SerdeTest.java
@@ -1,24 +1,36 @@
-/**
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- * <p>
- * http://www.apache.org/licenses/LICENSE-2.0
- * <p>
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
  */
 package org.apache.sedona.common.raster;
 
+import org.geotools.coverage.grid.GridCoordinates2D;
 import org.geotools.coverage.grid.GridCoverage2D;
+import org.geotools.geometry.DirectPosition2D;
+import org.geotools.referencing.CRS;
+import org.junit.Assert;
 import org.junit.Test;
+import org.opengis.geometry.DirectPosition;
+import org.opengis.geometry.Envelope;
+import org.opengis.referencing.crs.CoordinateReferenceSystem;
+import org.opengis.referencing.operation.TransformException;
 
 import java.io.IOException;
 
-import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 
 public class SerdeTest extends RasterTestBase {
@@ -28,7 +40,10 @@ public class SerdeTest extends RasterTestBase {
         byte[] bytes = Serde.serialize(oneBandRaster);
         GridCoverage2D raster = Serde.deserialize(bytes);
         assertNotNull(raster);
-        assertEquals(1, raster.getNumSampleDimensions());
+        assertSameCoverage(oneBandRaster, raster);
+        bytes = Serde.serialize(raster);
+        raster = Serde.deserialize(bytes);
+        assertSameCoverage(oneBandRaster, raster);
     }
 
     @Test
@@ -36,6 +51,62 @@ public class SerdeTest extends RasterTestBase {
         byte[] bytes = Serde.serialize(this.multiBandRaster);
         GridCoverage2D raster = Serde.deserialize(bytes);
         assertNotNull(raster);
-        assertEquals(4, raster.getNumSampleDimensions());
+        assertSameCoverage(multiBandRaster, raster);
+        bytes = Serde.serialize(raster);
+        raster = Serde.deserialize(bytes);
+        assertSameCoverage(multiBandRaster, raster);
     }
-}
\ No newline at end of file
+
+    private void assertSameCoverage(GridCoverage2D expected, GridCoverage2D 
actual) {
+        Assert.assertEquals(expected.getNumSampleDimensions(), 
actual.getNumSampleDimensions());
+        Envelope expectedEnvelope = expected.getEnvelope();
+        Envelope actualEnvelope = actual.getEnvelope();
+        assertSameEnvelope(expectedEnvelope, actualEnvelope, 1e-6);
+        CoordinateReferenceSystem expectedCrs = 
expected.getCoordinateReferenceSystem();
+        CoordinateReferenceSystem actualCrs = 
actual.getCoordinateReferenceSystem();
+        Assert.assertTrue(CRS.equalsIgnoreMetadata(expectedCrs, actualCrs));
+        assertSameValues(expected, actual, 10);
+    }
+
+    private void assertSameEnvelope(Envelope expected, Envelope actual, double 
epsilon) {
+        Assert.assertEquals(expected.getMinimum(0), actual.getMinimum(0), 
epsilon);
+        Assert.assertEquals(expected.getMinimum(1), actual.getMinimum(1), 
epsilon);
+        Assert.assertEquals(expected.getMaximum(0), actual.getMaximum(0), 
epsilon);
+        Assert.assertEquals(expected.getMaximum(1), actual.getMaximum(1), 
epsilon);
+    }
+
+    private void assertSameValues(GridCoverage2D expected, GridCoverage2D 
actual, int density) {
+        Envelope expectedEnvelope = expected.getEnvelope();
+        double x0 = expectedEnvelope.getMinimum(0);
+        double y0 = expectedEnvelope.getMinimum(1);
+        double xStep = (expectedEnvelope.getMaximum(0) - x0) / density;
+        double yStep = (expectedEnvelope.getMaximum(1) - y0) / density;
+        double[] expectedValues = new 
double[expected.getNumSampleDimensions()];
+        double[] actualValues = new double[expected.getNumSampleDimensions()];
+        int sampledPoints = 0;
+        for (int i = 0; i < density; i++) {
+            for (int j = 0; j < density; j++) {
+                double x = x0 + j * xStep;
+                double y = y0 + i * yStep;
+                DirectPosition position = new DirectPosition2D(x, y);
+                try {
+                    GridCoordinates2D gridPosition = 
expected.getGridGeometry().worldToGrid(position);
+                    if (Double.isNaN(gridPosition.getX()) || 
Double.isNaN(gridPosition.getY())) {
+                        // This position is outside the coverage
+                        continue;
+                    }
+                    expected.evaluate(position, expectedValues);
+                    actual.evaluate(position, actualValues);
+                    Assert.assertEquals(expectedValues.length, 
actualValues.length);
+                    for (int k = 0; k < expectedValues.length; k++) {
+                        Assert.assertEquals(expectedValues[k], 
actualValues[k], 1e-6);
+                    }
+                    sampledPoints += 1;
+                } catch (TransformException e) {
+                    throw new RuntimeException("Failed to convert world 
coordinate to grid coordinate", e);
+                }
+            }
+        }
+        Assert.assertTrue(sampledPoints > density * density / 2);
+    }
+}

Reply via email to