This is an automated email from the ASF dual-hosted git repository.

yecol pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-graphar.git


The following commit(s) were added to refs/heads/main by this push:
     new 1cca4035 feat(spark): read and write label chunk in spark (#718)
1cca4035 is described below

commit 1cca4035f271fd117406d1d555a5ce1d3cd7ad51
Author: Xiaokang Yang <[email protected]>
AuthorDate: Mon Jul 14 14:08:22 2025 +0800

    feat(spark): read and write label chunk in spark (#718)
    
    * feat(spark): read multi-label from parquet files
    
    * feat(spark): write multi-label to parquet files
    
    * add test
---
 .../java/org/apache/graphar/GeneralParams.java     |   1 +
 .../main/scala/org/apache/graphar/VertexInfo.scala |  12 ++-
 .../org/apache/graphar/graph/GraphReader.scala     |  37 ++++++-
 .../org/apache/graphar/reader/VertexReader.scala   |  44 +++++++-
 .../org/apache/graphar/writer/VertexWriter.scala   |  45 +++++++-
 .../apache/graphar/TestLabelReaderAndWriter.scala  | 117 +++++++++++++++++++++
 6 files changed, 241 insertions(+), 15 deletions(-)

diff --git 
a/maven-projects/spark/graphar/src/main/java/org/apache/graphar/GeneralParams.java
 
b/maven-projects/spark/graphar/src/main/java/org/apache/graphar/GeneralParams.java
index 8e47d1a4..91ae6e80 100644
--- 
a/maven-projects/spark/graphar/src/main/java/org/apache/graphar/GeneralParams.java
+++ 
b/maven-projects/spark/graphar/src/main/java/org/apache/graphar/GeneralParams.java
@@ -25,6 +25,7 @@ import org.apache.spark.storage.StorageLevel;
 public class GeneralParams {
     // column name
     public static final String vertexIndexCol = "_graphArVertexIndex";
+    public static final String kLabelCol = ":LABEL";
     public static final String srcIndexCol = "_graphArSrcIndex";
     public static final String dstIndexCol = "_graphArDstIndex";
     public static final String offsetCol = "_graphArOffset";
diff --git 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/VertexInfo.scala
 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/VertexInfo.scala
index 8c8b1e36..f6e5e935 100644
--- 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/VertexInfo.scala
+++ 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/VertexInfo.scala
@@ -20,18 +20,22 @@
 package org.apache.graphar
 
 import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.{SparkSession}
-import org.yaml.snakeyaml.{Yaml, DumperOptions}
+import org.apache.spark.sql.SparkSession
+import org.yaml.snakeyaml.{DumperOptions, Yaml}
 import org.yaml.snakeyaml.constructor.Constructor
+
 import scala.beans.BeanProperty
 import org.yaml.snakeyaml.LoaderOptions
 
+import java.util
+
 /** VertexInfo is a class to store the vertex meta information. */
 class VertexInfo() {
   @BeanProperty var `type`: String = ""
   @BeanProperty var chunk_size: Long = 0
   @BeanProperty var prefix: String = ""
   @BeanProperty var property_groups = new java.util.ArrayList[PropertyGroup]()
+  @BeanProperty var labels = new java.util.ArrayList[String]()
   @BeanProperty var version: String = ""
 
   /**
@@ -284,6 +288,10 @@ class VertexInfo() {
   def dump(): String = {
     val data = new java.util.HashMap[String, Object]()
     data.put("type", `type`)
+    val labels_num = labels.size()
+    if (labels_num > 0) {
+      data.put("labels", labels)
+    }
     data.put("chunk_size", new java.lang.Long(chunk_size))
     if (prefix != "") data.put("prefix", prefix)
     data.put("version", version)
diff --git 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/graph/GraphReader.scala
 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/graph/GraphReader.scala
index 248c9a65..c52a149b 100644
--- 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/graph/GraphReader.scala
+++ 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/graph/GraphReader.scala
@@ -19,9 +19,8 @@
 
 package org.apache.graphar.graph
 
-import org.apache.graphar.{GraphInfo, VertexInfo, EdgeInfo}
-import org.apache.graphar.reader.{VertexReader, EdgeReader}
-
+import org.apache.graphar.{EdgeInfo, GeneralParams, GraphInfo, VertexInfo}
+import org.apache.graphar.reader.{EdgeReader, VertexReader}
 import org.apache.spark.sql.{DataFrame, SparkSession}
 
 /**
@@ -48,13 +47,41 @@ object GraphReader {
   ): Map[String, DataFrame] = {
     val vertex_dataframes: Map[String, DataFrame] = vertexInfos.map {
       case (vertex_type, vertexInfo) => {
-        val reader = new VertexReader(prefix, vertexInfo, spark)
-        (vertex_type, reader.readAllVertexPropertyGroups())
+        (vertex_type, readVertexWithLabels(prefix, vertexInfo, spark))
       }
     }
     return vertex_dataframes
   }
 
+  /**
+   * Loads the vertex chunks as DataFrame with the vertex infos.
+   *
+   * @param prefix
+   *   The absolute prefix.
+   * @param vertexInfos
+   *   The map of (vertex type -> VertexInfo) for the graph.
+   * @param spark
+   *   The Spark session for the reading.
+   * @return
+   *   The map of (vertex type -> DataFrame)
+   */
+  def readVertexWithLabels(
+      prefix: String,
+      vertexInfo: VertexInfo,
+      spark: SparkSession
+  ): DataFrame = {
+    val reader = new VertexReader(prefix, vertexInfo, spark)
+    val frame = reader.readAllVertexPropertyGroups()
+    if (vertexInfo.labels.isEmpty) {
+      return frame
+    }
+    val label_frame = reader.readVertexLabels()
+    if (label_frame.isEmpty) {
+      return frame
+    }
+    frame.join(label_frame, GeneralParams.vertexIndexCol, "left")
+  }
+
   /**
    * Loads the edge chunks as DataFrame with the edge infos.
    *
diff --git 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/reader/VertexReader.scala
 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/reader/VertexReader.scala
index fae0505c..c9527e39 100644
--- 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/reader/VertexReader.scala
+++ 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/reader/VertexReader.scala
@@ -19,13 +19,16 @@
 
 package org.apache.graphar.reader
 
-import org.apache.graphar.util.{IndexGenerator, DataFrameConcat}
-import org.apache.graphar.{VertexInfo, PropertyGroup, GeneralParams}
+import org.apache.graphar.util.{DataFrameConcat, IndexGenerator}
+import org.apache.graphar.{GeneralParams, PropertyGroup, VertexInfo}
 import org.apache.graphar.util.FileSystem
-
-import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.functions.{col, struct, udf}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.types._
 
+import scala.collection.convert.ImplicitConversions.`list asScalaBuffer`
+import scala.collection.mutable.ListBuffer
+
 /**
  * Reader for vertex chunks.
  *
@@ -154,4 +157,37 @@ class VertexReader(
     val property_groups = vertexInfo.getProperty_groups()
     return readMultipleVertexPropertyGroups(property_groups)
   }
+
+  def readVertexLabels(): DataFrame = {
+    if (vertexInfo.labels.size() > 0) {
+      val labels_list = vertexInfo.labels.toSeq
+      val parquet_read_path = prefix + vertexInfo.prefix + "/labels"
+      val rdd = spark.read
+        .option("fileFormat", "parquet")
+        .option("header", "true")
+        .format("org.apache.graphar.datasources.GarDataSource")
+        .load(parquet_read_path)
+        .rdd
+        .map { row =>
+          val labelsBuffer = ListBuffer[String]()
+          for ((colName, index) <- labels_list.zipWithIndex) {
+            if (row.getBoolean(index)) {
+              labelsBuffer += colName
+            }
+          }
+          Row.apply(labelsBuffer.toSeq)
+        }
+        // TODO(yangxk) read the vertexIndexCol from the file.
+        .zipWithIndex()
+        .map { case (row, idx) => Row.fromSeq(row.toSeq :+ idx) }
+      val schema_array = Array(
+        StructField(GeneralParams.kLabelCol, ArrayType(StringType)),
+        StructField(GeneralParams.vertexIndexCol, LongType)
+      )
+      val schema = StructType(schema_array)
+      spark.createDataFrame(rdd, schema)
+    } else {
+      return spark.emptyDataFrame
+    }
+  }
 }
diff --git 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/writer/VertexWriter.scala
 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/writer/VertexWriter.scala
index 28d62b48..255c702b 100644
--- 
a/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/writer/VertexWriter.scala
+++ 
b/maven-projects/spark/graphar/src/main/scala/org/apache/graphar/writer/VertexWriter.scala
@@ -19,14 +19,14 @@
 
 package org.apache.graphar.writer
 
-import org.apache.graphar.util.{FileSystem, ChunkPartitioner, IndexGenerator}
-import org.apache.graphar.{GeneralParams, VertexInfo, PropertyGroup}
-
+import org.apache.graphar.util.{ChunkPartitioner, FileSystem, IndexGenerator}
+import org.apache.graphar.{GeneralParams, PropertyGroup, VertexInfo}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.{LongType, StructField}
 
+import scala.collection.convert.ImplicitConversions.`collection 
AsScalaIterable`
 import scala.collection.mutable.ArrayBuffer
 
 /** Helper object for VertexWriter class. */
@@ -157,6 +157,43 @@ class VertexWriter(
     }
   }
 
+  def writeVertexLabels(): Unit = {
+    if (vertexInfo.labels.isEmpty) {
+      throw new IllegalArgumentException(
+        "vertex does not have labels."
+      )
+    }
+
+    // write out the chunks
+    val output_prefix = prefix + vertexInfo.prefix + "/labels"
+    val labels_list = vertexInfo.labels.toSeq
+    val label_num = labels_list.length
+    val labels_list_rdd =
+      chunks.select(col(GeneralParams.kLabelCol)).rdd.map { row =>
+        val labels = row.getSeq(0)
+        val bools = new Array[Boolean](label_num)
+        var i = 0
+        while (i < label_num) {
+          bools(i) = labels.contains(labels_list(i))
+          i += 1
+        }
+        Row.fromSeq(bools)
+      }
+    val schema = StructType(
+      labels_list.map(label =>
+        StructField(label, BooleanType, nullable = false)
+      )
+    )
+    val labelDf = spark.createDataFrame(labels_list_rdd, schema)
+    FileSystem.writeDataFrame(
+      labelDf,
+      "parquet",
+      output_prefix,
+      None,
+      None
+    )
+  }
+
   override def finalize(): Unit = {
     chunks.unpersist()
   }
diff --git 
a/maven-projects/spark/graphar/src/test/scala/org/apache/graphar/TestLabelReaderAndWriter.scala
 
b/maven-projects/spark/graphar/src/test/scala/org/apache/graphar/TestLabelReaderAndWriter.scala
new file mode 100644
index 00000000..00211b27
--- /dev/null
+++ 
b/maven-projects/spark/graphar/src/test/scala/org/apache/graphar/TestLabelReaderAndWriter.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.graphar
+
+import org.apache.graphar.graph.GraphReader
+import org.apache.graphar.graph.GraphReader.{
+  readVertexWithLabels,
+  readWithGraphInfo
+}
+import org.apache.graphar.reader.VertexReader
+import org.apache.graphar.writer.VertexWriter
+import org.apache.spark.sql.types.{ArrayType, StringType, StructField}
+
+class LabelReaderAndWriterSuite extends BaseTestSuite {
+
+  test("read vertices with labels from parquet file") {
+    val prefix = testData + "/ldbc/parquet/"
+    val vertex_yaml = prefix + "organisation.vertex.yml"
+    val vertex_info = VertexInfo.loadVertexInfo(vertex_yaml, spark)
+    val frame = readVertexWithLabels(prefix, vertex_info, spark)
+    assert(
+      frame.schema.contains(
+        StructField(GeneralParams.kLabelCol, ArrayType(StringType))
+      )
+    )
+    frame.select(GeneralParams.kLabelCol).show()
+  }
+
+  test("read vertices with labels with graph reader") {
+    val prefix = testData + "/ldbc/parquet/"
+    val vertex_yaml = prefix + "organisation.vertex.yml"
+    val vertex_info = VertexInfo.loadVertexInfo(vertex_yaml, spark)
+    val graph_info = new GraphInfo()
+    graph_info.addVertexInfo(vertex_info)
+    graph_info.setPrefix(prefix)
+    var vertex_edge_df_pair = readWithGraphInfo(graph_info, spark)
+    val vertex_dataframes = vertex_edge_df_pair._1
+    println(vertex_dataframes.keys)
+    val frame_organisation = vertex_dataframes("organisation")
+    assert(
+      frame_organisation.schema.contains(
+        StructField(GeneralParams.kLabelCol, ArrayType(StringType))
+      )
+    )
+    frame_organisation
+      .select(GeneralParams.vertexIndexCol, GeneralParams.kLabelCol)
+      .show()
+  }
+
+  test("write vertices with labels to parquet file") {
+    // read vertex DataFrame
+    val prefix = testData + "/ldbc/parquet/"
+    val vertex_yaml = prefix + "organisation.vertex.yml"
+    val vertex_info = VertexInfo.loadVertexInfo(vertex_yaml, spark)
+    val frame = readVertexWithLabels(prefix, vertex_info, spark)
+    assert(
+      frame.schema.contains(
+        StructField(GeneralParams.kLabelCol, ArrayType(StringType))
+      )
+    )
+    val output_prefix: String = "/tmp/"
+    val writer = new VertexWriter(output_prefix, vertex_info, frame)
+    writer.writeVertexLabels()
+    val reader = new VertexReader(output_prefix, vertex_info, spark)
+    val frame2 = reader.readVertexLabels().select(GeneralParams.kLabelCol)
+    frame2.show()
+    assert(
+      frame2.schema.contains(
+        StructField(GeneralParams.kLabelCol, ArrayType(StringType))
+      )
+    )
+    val frame1 = frame.select(GeneralParams.kLabelCol)
+    val diff1 = frame1.except(frame2)
+    val diff2 = frame2.except(frame1)
+    assert(diff1.isEmpty)
+    assert(diff2.isEmpty)
+  }
+
+  test("read vertices without labels") {
+    val prefix = testData + "/ldbc_sample/parquet/"
+    val vertex_yaml = prefix + "person.vertex.yml"
+    val vertex_info = VertexInfo.loadVertexInfo(vertex_yaml, spark)
+    val frame = readVertexWithLabels(prefix, vertex_info, spark)
+    frame.show()
+    assert(
+      !frame.schema.contains(
+        StructField(GeneralParams.kLabelCol, ArrayType(StringType))
+      )
+    )
+  }
+
+  test("add vertex labels dump") {
+    val prefix = testData + "/ldbc_sample/parquet/"
+    val vertex_yaml = prefix + "person.vertex.yml"
+    val vertex_info = VertexInfo.loadVertexInfo(vertex_yaml, spark)
+    vertex_info.labels.add("Employee")
+    println(vertex_info.dump())
+  }
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to