This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch osm-reader
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 1a3bfa4bf035d6491a25846132b10aa9696835b0
Author: pawelkocinski <[email protected]>
AuthorDate: Thu Nov 28 15:09:28 2024 +0100

    Add initial version of OSM PBF file reader.
---
 .../sql/datasources/osmpbf/OsmPbfReader.java       |  26 +++++-
 .../osmpbf/features/RelationParser.java            |  77 ++++++++++++++++
 .../sql/datasources/osmpbf/model/OsmPbfRecord.java |   9 ++
 .../sql/datasources/osmpbf/model/OsmRelation.java  |  35 +++++++
 .../sql/datasources/osmpbf/model/RelationType.java |  20 ++++
 .../datasources/osmpbf/OsmPbfPartitionReader.scala | 101 ++++++++++++++-------
 .../sql/datasources/osmpbf/OsmPbfScanBuilder.scala |   6 +-
 .../sql/datasources/osmpbf/OsmPbfTable.scala       |  26 ++++--
 .../org/apache/sedona/sql/OsmNodeReaderTest.scala  |   6 +-
 9 files changed, 257 insertions(+), 49 deletions(-)

diff --git 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/OsmPbfReader.java
 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/OsmPbfReader.java
index 5909f3be2b..dcda4ca375 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/OsmPbfReader.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/OsmPbfReader.java
@@ -2,14 +2,17 @@ package org.apache.sedona.sql.datasources.osmpbf;
 
 import java.io.ByteArrayInputStream;
 import java.io.DataInputStream;
+import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
+import java.io.RandomAccessFile;
 import java.util.ArrayList;
 import java.util.zip.DataFormatException;
 import java.util.zip.Inflater;
 import org.apache.sedona.sql.datasources.osmpbf.build.Fileformat;
 import org.apache.sedona.sql.datasources.osmpbf.build.Osmformat;
 import org.apache.sedona.sql.datasources.osmpbf.features.DenseNodeParser;
+import org.apache.sedona.sql.datasources.osmpbf.features.RelationParser;
 import org.apache.sedona.sql.datasources.osmpbf.features.WayParser;
 import org.apache.sedona.sql.datasources.osmpbf.model.OsmPbfRecord;
 
@@ -28,10 +31,18 @@ public class OsmPbfReader {
     int recordIndex = 0;
 
     public OsmPbfReader(OsmPbfOptions options) throws IOException {
-        this.stream = new FileInputStream(options.inputPath);
-        stream.skip(options.startOffset);
+        long currentTime = System.currentTimeMillis();
+
+        RandomAccessFile randomAccessFile = new RandomAccessFile(new 
File(options.inputPath), "r");
+        randomAccessFile.seek(options.startOffset);
+
+        this.pbfStream = new DataInputStream(
+                new FileInputStream(randomAccessFile.getFD())
+        );
+        this.stream = new FileInputStream(randomAccessFile.getFD());
+
+        System.out.println("Time to read file: " + (System.currentTimeMillis() 
- currentTime) + "for file offset " + options.startOffset + " - " + 
options.endOffset);
 
-        this.pbfStream = new DataInputStream(stream);
         endOffset = options.endOffset;
         startOffset = options.startOffset;
     }
@@ -98,16 +109,21 @@ public class OsmPbfReader {
         int granularity = pb.getGranularity();
         Osmformat.StringTable stringTable = pb.getStringtable();
 
+        records = new ArrayList<>();
+
         for (Osmformat.PrimitiveGroup group : pb.getPrimitivegroupList()) {
             OsmDataType type = FeatureParser.getType(group);
 
             if (type == OsmDataType.DENSE_NODE) {
                 DenseNodeParser denseNodeParser = new 
DenseNodeParser(granularity, latOffset, lonOffset);
-                records = denseNodeParser.parse(group.getDense(), stringTable);
+                records.addAll(denseNodeParser.parse(group.getDense(), 
stringTable));
                 recordIndex = 0;
             } else if (type == OsmDataType.WAY) {
                 WayParser wayParser = new WayParser();
-                records = wayParser.parse(group.getWaysList(), stringTable);
+                records.addAll(wayParser.parse(group.getWaysList(), 
stringTable));
+                recordIndex = 0;
+            } else if (type == OsmDataType.RELATION) {
+                records.addAll(RelationParser.parse(group.getRelationsList(), 
stringTable));
                 recordIndex = 0;
             }
              else {
diff --git 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/features/RelationParser.java
 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/features/RelationParser.java
new file mode 100644
index 0000000000..2b2adb6984
--- /dev/null
+++ 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/features/RelationParser.java
@@ -0,0 +1,77 @@
+package org.apache.sedona.sql.datasources.osmpbf.features;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import org.apache.sedona.sql.datasources.osmpbf.build.Osmformat;
+import org.apache.sedona.sql.datasources.osmpbf.model.OsmPbfRecord;
+import org.apache.sedona.sql.datasources.osmpbf.model.OsmRelation;
+import org.apache.sedona.sql.datasources.osmpbf.model.RelationType;
+
+public class RelationParser {
+    public static ArrayList<OsmPbfRecord> parse(List<Osmformat.Relation> 
relations, Osmformat.StringTable stringTable) {
+        ArrayList<OsmPbfRecord> records = new ArrayList<>();
+        if (relations == null || relations.isEmpty()) {
+            return records;
+        }
+
+        for (Osmformat.Relation relation : relations) {
+            List<Long> memberIds = resolveMemberIds(relation);
+            List<String> memberTypes = resolveTypes(relation);
+
+            HashMap<String, String> tags = resolveTags(relation, stringTable);
+
+            OsmRelation relationObj = new OsmRelation(
+                    relation.getId(),
+                    memberIds,
+                    tags,
+                    memberTypes
+            );
+
+            records.add(new OsmPbfRecord(relationObj));
+        }
+
+        return records;
+    }
+
+    public static List<Long> resolveMemberIds(Osmformat.Relation relation) {
+        List<Long> memberIds = new ArrayList<>();
+
+        if (relation.getMemidsCount() != 0) {
+            long firstId = relation.getMemids(0);
+            memberIds.add(firstId);
+
+            for (int i = 1; i < relation.getMemidsCount(); i++) {
+                memberIds.add(relation.getMemids(i) + firstId);
+            }
+        }
+
+        return memberIds;
+    }
+
+    public static List<String> resolveTypes(Osmformat.Relation relation) {
+        List<String> types = new ArrayList<>();
+
+        for (int i = 0; i < relation.getTypesCount(); i++) {
+            Osmformat.Relation.MemberType memberType = relation.getTypes(i);
+            types.add(RelationType.fromValue(memberType.getNumber()));
+        }
+
+        return types;
+    }
+
+    public static HashMap<String, String> resolveTags(Osmformat.Relation 
relation, Osmformat.StringTable stringTable) {
+        HashMap<String, String> tags = new HashMap<>();
+
+        for (int i = 0; i < relation.getKeysCount(); i++) {
+            int key = relation.getKeys(i);
+            int value = relation.getVals(i);
+
+            String keyString = stringTable.getS(key).toStringUtf8();
+            String valueString = stringTable.getS(value).toStringUtf8();
+            tags.put(keyString, valueString);
+        }
+
+        return tags;
+    }
+}
diff --git 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmPbfRecord.java
 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmPbfRecord.java
index ae775d1086..27fb5644ba 100644
--- 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmPbfRecord.java
+++ 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmPbfRecord.java
@@ -3,6 +3,7 @@ package org.apache.sedona.sql.datasources.osmpbf.model;
 public class OsmPbfRecord {
     OsmNode node;
     OsmWay way;
+    OsmRelation relation;
 
     public OsmPbfRecord(OsmNode node) {
         this.node = node;
@@ -12,6 +13,10 @@ public class OsmPbfRecord {
         this.way = way;
     }
 
+    public OsmPbfRecord(OsmRelation relation) {
+        this.relation = relation;
+    }
+
 
     public OsmPbfRecord() {
     }
@@ -23,4 +28,8 @@ public class OsmPbfRecord {
     public OsmWay getWay() {
         return way;
     }
+
+    public OsmRelation getRelation() {
+        return relation;
+    }
 }
diff --git 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmRelation.java
 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmRelation.java
new file mode 100644
index 0000000000..ddbee29c6c
--- /dev/null
+++ 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/OsmRelation.java
@@ -0,0 +1,35 @@
// Original package: org.apache.sedona.sql.datasources.osmpbf.model

import java.util.HashMap;
import java.util.List;

/**
 * Immutable view of one OSM relation: its id, the absolute ids of its
 * members, the member type names, and the relation's tags.
 *
 * <p>Note: the collections passed to the constructor are stored as-is
 * (no defensive copy); callers must not mutate them afterwards.
 */
public class OsmRelation {
    // Fields are final: this is a pure value object populated once by the
    // parser. (The unused java.util.ArrayList import was removed.)
    private final HashMap<String, String> tags;
    private final List<String> types;
    private final long id;
    private final List<Long> memberIds;

    /**
     * @param id        the relation's OSM id
     * @param memberIds absolute (already delta-decoded) member ids, in member order
     * @param tags      the relation's key/value tags
     * @param types     member type names ("NODE"/"WAY"/"RELATION"), parallel to memberIds
     */
    public OsmRelation(
            long id, List<Long> memberIds, HashMap<String, String> tags, List<String> types) {
        this.id = id;
        this.memberIds = memberIds;
        this.tags = tags;
        this.types = types;
    }

    /** @return the relation's OSM id */
    public long getId() {
        return id;
    }

    /** @return absolute member ids, in member order */
    public List<Long> getMemberIds() {
        return memberIds;
    }

    /** @return member type names, parallel to {@link #getMemberIds()} */
    public List<String> getTypes() {
        return types;
    }

    /** @return the relation's key/value tags */
    public HashMap<String, String> getTags() {
        return tags;
    }
}
diff --git 
a/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/RelationType.java
 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/RelationType.java
new file mode 100644
index 0000000000..c3e6918160
--- /dev/null
+++ 
b/spark/common/src/main/java/org/apache/sedona/sql/datasources/osmpbf/model/RelationType.java
@@ -0,0 +1,20 @@
// Original package: org.apache.sedona.sql.datasources.osmpbf.model

/**
 * Member types that can appear in an OSM relation, declared in the same
 * order as the protobuf MemberType wire values (0 = NODE, 1 = WAY,
 * 2 = RELATION) so the wire number doubles as the ordinal.
 */
public enum RelationType {
    NODE,
    WAY,
    RELATION;

    /**
     * Translates a protobuf MemberType wire value into the matching
     * type name.
     *
     * @param number wire value of the member type (0, 1 or 2)
     * @return the name of the matching constant
     * @throws IllegalArgumentException for any other value
     */
    public static String fromValue(int number) {
        RelationType[] constants = values();
        if (number < 0 || number >= constants.length) {
            throw new IllegalArgumentException("Unknown relation type: " + number);
        }
        return constants[number].toString();
    }
}
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfPartitionReader.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfPartitionReader.scala
index c1bf34ddd7..1b3edffb7d 100644
--- 
a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfPartitionReader.scala
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfPartitionReader.scala
@@ -2,7 +2,7 @@ package org.apache.sedona.sql.datasources.osmpbf
 
 import org.apache.sedona.sql.datasources.geopackage.model.GeoPackageReadOptions
 import org.apache.sedona.sql.datasources.osmpbf.build.Osmformat.Way
-import org.apache.sedona.sql.datasources.osmpbf.model.OsmNode
+import org.apache.sedona.sql.datasources.osmpbf.model.{OsmNode, OsmPbfRecord}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, 
GenericArrayData}
@@ -13,57 +13,97 @@ import org.apache.spark.util.SerializableConfiguration
 import java.io.File
 import java.util
 import scala.collection.immutable.Seq
+import scala.reflect.ClassTag
 
 case class OsmPbfPartitionReader(
   reader: OsmPbfReader
 ) extends PartitionReader[InternalRow] {
 
+  implicit val f1: () => Seq[Long] = () => Seq().map(Long.unbox)
+  implicit val f2: () => Seq[UTF8String] = () => Seq()
+  implicit val t1 = (x: java.lang.Long) => x.toLong
+  implicit val t2 = (x: String) => UTF8String.fromString(x)
+
   override def next(): Boolean = reader.next()
 
   override def get(): InternalRow = {
-    val pbfData = reader.get()
+    val record = reader.get()
+    if (record == null) {
+      return InternalRow.fromSeq(Seq(null, null, null))
+    }
+
+    InternalRow.fromSeq(Seq(
+      resolveNode(record),
+      resolveWay(record),
+      resolveRelation(record)
+    ))
+  }
 
-    val node = pbfData.getNode
-    if (node == null && pbfData.getWay == null) {
-      return InternalRow.fromSeq(Seq(null, null, null, null))
+  private def resolveRelation(record: OsmPbfRecord): InternalRow = {
+    val relation = record.getRelation
+
+    if (relation == null) {
+      return null
     }
 
-    if (pbfData.getWay != null) {
-      val way = pbfData.getWay
+    val tags = transformTags(relation.getTags)
+
+    InternalRow.fromSeq(Seq(
+      relation.getId,
+      transformList[java.lang.Long, Long](relation.getMemberIds),
+      transformList[java.lang.String, UTF8String](relation.getTypes),
+      tags
+    ))
+  }
+
+  private def resolveWay(record: OsmPbfRecord): InternalRow = {
+    val way = record.getWay
 
-      var refs = Seq().map(Long.unbox)
+    if (way == null) {
+      return null
+    }
 
-      pbfData.getWay.getRefs.forEach(r => {
-        refs :+= r.toLong
-      })
+    val tags = transformTags(way.getTags)
 
-      val refsArray = ArrayData.toArrayData(Array(
-        refs:_*
-      ))
+    InternalRow.fromSeq(Seq(
+      way.getId,
+      transformList[java.lang.Long, Long](way.getRefs),
+      tags
+    ))
+  }
 
-      val refsMap = transformTags(pbfData.getWay.getTags)
+  private def resolveNode(record: OsmPbfRecord): InternalRow = {
+    val node = record.getNode
 
-      return InternalRow.fromSeq(Seq(
-        null, null, null,
-        InternalRow.fromSeq(
-          Seq(
-            way.getId,
-            refsArray,
-            refsMap
-          )
-        )
-      ))
+    if (node == null) {
+      return null
     }
 
     InternalRow.fromSeq(Seq(
-      UTF8String.fromString(pbfData.getNode.getId.toString),
+      UTF8String.fromString(record.getNode.getId.toString),
       InternalRow.fromSeq(Seq(
-        pbfData.getNode.getLatitude,
-        pbfData.getNode.getLongitude,
+        node.getLatitude,
+        node.getLongitude,
       )),
-      transformTags(pbfData.getNode.getTags),
-      null,
+      transformTags(node.getTags)
+    ))
+  }
+
+  def transformList[T <: java.lang.Object, R: ClassTag](data: util.List[T])(
+    implicit f: () => Seq[R],
+    t: T => R,
+  ): ArrayData = {
+    var refs = f()
+
+    data.forEach(r => {
+      refs :+= t(r)
+    })
+
+    val refsArray = ArrayData.toArrayData(Array(
+      refs:_*
     ))
+
+    refsArray
   }
 
   def transformTags(tags: util.Map[String, String]): ArrayBasedMapData = {
@@ -87,5 +127,4 @@ case class OsmPbfPartitionReader(
   }
 
   override def close(): Unit = {}
-
 }
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfScanBuilder.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfScanBuilder.scala
index eb3502caa7..2867218939 100644
--- 
a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfScanBuilder.scala
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfScanBuilder.scala
@@ -16,7 +16,7 @@ case class OsmPbfScanBuilder(
   path: String,
 ) extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
 
-  val factor = 12
+  val factor = 100
 
   override def build(): Scan = {
     val paths = fileIndex.allFiles().map(_.getPath.toString)
@@ -27,10 +27,10 @@ case class OsmPbfScanBuilder(
 
     val length = stream.available
 
-    val chunk = length/factor
+    val chunk = length/(factor*10)
 
     val chunks = (0 until factor).map { i =>
-      (i * chunk, (i+1) * chunk)
+      ((i*10 * ()) * chunk, ((i+1)/10) * chunk)
     }
 
     val fileIndexAdjusted = new InMemoryFileIndex(
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfTable.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfTable.scala
index da97f01805..8d443170a6 100644
--- 
a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfTable.scala
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/datasources/osmpbf/OsmPbfTable.scala
@@ -18,21 +18,33 @@ case class OsmPbfTable(
 
   override def inferSchema(files: Seq[FileStatus]): Option[StructType] = {
     Some(StructType(Seq(
-      StructField("id", StringType, nullable = false),
-      StructField("location", StructType(
+      StructField("node", StructType(
         Seq(
-          StructField("longitude", DoubleType, nullable = false),
-          StructField("latitude", DoubleType, nullable = false),
+          StructField("id", LongType, nullable = false),
+          StructField("location", StructType(
+            Seq(
+              StructField("longitude", DoubleType, nullable = false),
+              StructField("latitude", DoubleType, nullable = false),
+            )
+          ), nullable = false),
+          StructField("tags", MapType(StringType, StringType, 
valueContainsNull = true), nullable = true),
         )
-      ), nullable = false),
-      StructField("tags", MapType(StringType, StringType, valueContainsNull = 
true), nullable = true),
-      StructField("ways", StructType(
+      ), nullable = true),
+      StructField("way", StructType(
         Seq(
           StructField("id", LongType, nullable = false),
           StructField("refs", ArrayType(LongType), nullable = false),
           StructField("tags", MapType(StringType, StringType, 
valueContainsNull = true), nullable = true),
         )
       ), nullable = true),
+      StructField("relation", StructType(
+        Seq(
+          StructField("id", LongType, nullable = false),
+          StructField("member_ids", ArrayType(LongType), nullable = false),
+          StructField("types", ArrayType(StringType), nullable = false),
+          StructField("tags", MapType(StringType, StringType, 
valueContainsNull = true), nullable = true),
+        )
+      ), nullable = true),
     )
     ))
   }
diff --git 
a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/OsmNodeReaderTest.scala 
b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/OsmNodeReaderTest.scala
index ea4e09c33d..08609d0e5b 100644
--- 
a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/OsmNodeReaderTest.scala
+++ 
b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/OsmNodeReaderTest.scala
@@ -9,13 +9,13 @@ class OsmNodeReaderTest extends TestBaseScala with Matchers {
       sparkSession
         .read
         .format("osmpbf")
-        
.load("/Users/pawelkocinski/Desktop/projects/osm-data-reader/src/main/resources/lubuskie-latest.osm.pbf")
+        
.load("/Users/pawelkocinski/Desktop/projects/osm-data-reader/src/main/resources/poland-latest.osm.pbf")
 //        
.load("/Users/pawelkocinski/Desktop/projects/osm-data-reader/src/main/resources/poland-latest.osm.pbf")
         .createOrReplaceTempView("osm")
 
       sparkSession.sql("SELECT * FROM osm")
-        .where("ways is not null and size(ways.tags) > 0")
-        .show(5, false)
+//        .where("relation is not null AND size(relation.tags) > 0")
+        .count()
     }
   }
 

Reply via email to