carloea2 commented on code in PR #4136:
URL: https://github.com/apache/texera/pull/4136#discussion_r2656970594


##########
file-service/src/main/scala/org/apache/texera/service/resource/DatasetResource.scala:
##########
@@ -1372,4 +1405,378 @@ class DatasetResource {
         Right(response)
     }
   }
+
+  // === Multipart helpers ===
+
+  private def getDatasetBy(ownerEmail: String, datasetName: String) = {
+    val dataset = context
+      .select(DATASET.fields: _*)
+      .from(DATASET)
+      .leftJoin(USER)
+      .on(USER.UID.eq(DATASET.OWNER_UID))
+      .where(USER.EMAIL.eq(ownerEmail))
+      .and(DATASET.NAME.eq(datasetName))
+      .fetchOneInto(classOf[Dataset])
+    if (dataset == null) {
+      throw new BadRequestException("Dataset not found")
+    }
+    dataset
+  }
+
+  private def validateFilePathOrThrow(filePath: String): String = {
+    val p = Option(filePath).getOrElse("")
+    val s = p.replace("\\", "/")
+    if (
+      p.isEmpty ||
+      s.startsWith("/") ||
+      s.split("/").exists(seg => seg == "." || seg == "..") ||
+      s.exists(ch => ch == 0.toChar || ch < 0x20.toChar || ch == 0x7f.toChar)
+    ) throw new BadRequestException("Invalid filePath")
+    p
+  }
+
  /** Initializes a presigned multipart upload for a file in a dataset.
   *
   *  Flow (all inside a single DB transaction):
   *    1. authorize the requesting user for write access on the dataset,
   *    2. URL-decode and validate the target file path, validate numParts,
   *    3. reject with 409 if a session already exists for (uid, did, filePath),
   *    4. ask LakeFS for presigned multipart upload handles,
   *    5. persist the session row plus one placeholder row per part.
   *
   *  If anything fails after the LakeFS call, the LakeFS multipart upload is
   *  aborted (best effort) before the exception is rethrown; the session/part
   *  rows are removed by the transaction rollback.
   *
   *  @param did             dataset id
   *  @param encodedFilePath URL-encoded relative path of the file inside the dataset
   *  @param numParts        required part count (1..MAXIMUM_NUM_OF_MULTIPART_S3_PARTS)
   *  @param uid             id of the requesting user
   *  @return 200 OK on success
   */
  private def initMultipartUpload(
      did: Integer,
      encodedFilePath: String,
      numParts: Optional[Integer],
      uid: Integer
  ): Response = {

    withTransaction(context) { ctx =>
      // Authorization: only users with write access may start an upload.
      if (!userHasWriteAccess(ctx, did, uid)) {
        throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE)
      }

      val dataset = getDatasetByID(ctx, did)
      val repositoryName = dataset.getRepositoryName

      // Decode from the URL, then reject absolute/traversal/control-char paths.
      val filePath =
        validateFilePathOrThrow(URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name()))

      val numPartsValue = numParts.toScala.getOrElse {
        throw new BadRequestException("numParts is required for initialization")
      }
      if (numPartsValue < 1 || numPartsValue > MAXIMUM_NUM_OF_MULTIPART_S3_PARTS) {
        throw new BadRequestException(
          "numParts must be between 1 and " + MAXIMUM_NUM_OF_MULTIPART_S3_PARTS
        )
      }

      // Reject if a session already exists
      val exists = ctx.fetchExists(
        ctx
          .selectOne()
          .from(DATASET_UPLOAD_SESSION)
          .where(
            DATASET_UPLOAD_SESSION.UID
              .eq(uid)
              .and(DATASET_UPLOAD_SESSION.DID.eq(did))
              .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath))
          )
      )
      if (exists) {
        throw new WebApplicationException(
          "Upload already in progress for this filePath",
          Response.Status.CONFLICT
        )
      }

      // From here on we own an external LakeFS resource that must be aborted
      // if the session cannot be persisted.
      val presign = LakeFSStorageClient.initiatePresignedMultipartUploads(
        repositoryName,
        filePath,
        numPartsValue
      )

      val uploadIdStr = presign.getUploadId
      val physicalAddr = presign.getPhysicalAddress

      // If anything fails after this point, abort LakeFS multipart
      try {
        // onDuplicateKeyIgnore guards the race where two requests pass the
        // fetchExists check above concurrently; the loser inserts 0 rows.
        val rowsInserted = ctx
          .insertInto(DATASET_UPLOAD_SESSION)
          .set(DATASET_UPLOAD_SESSION.FILE_PATH, filePath)
          .set(DATASET_UPLOAD_SESSION.DID, did)
          .set(DATASET_UPLOAD_SESSION.UID, uid)
          .set(DATASET_UPLOAD_SESSION.UPLOAD_ID, uploadIdStr)
          .set(DATASET_UPLOAD_SESSION.PHYSICAL_ADDRESS, physicalAddr)
          .set(DATASET_UPLOAD_SESSION.NUM_PARTS_REQUESTED, numPartsValue)
          .onDuplicateKeyIgnore()
          .execute()

        if (rowsInserted != 1) {
          // Lost the insert race: release the LakeFS upload we just created.
          LakeFSStorageClient.abortPresignedMultipartUploads(
            repositoryName,
            filePath,
            uploadIdStr,
            physicalAddr
          )
          throw new WebApplicationException(
            "Upload already in progress for this filePath",
            Response.Status.CONFLICT
          )
        }

        // Pre-create part rows 1..numPartsValue with empty ETag.
        // This makes per-part locking cheap and deterministic.

        val gs = DSL.generateSeries(1, numPartsValue).asTable("gs", "pn")
        val PN = gs.field("pn", classOf[Integer])

        ctx
          .insertInto(
            DATASET_UPLOAD_SESSION_PART,
            DATASET_UPLOAD_SESSION_PART.UPLOAD_ID,
            DATASET_UPLOAD_SESSION_PART.PART_NUMBER,
            DATASET_UPLOAD_SESSION_PART.ETAG
          )
          .select(
            ctx
              .select(
                inl(uploadIdStr),
                PN,
                inl("") // placeholder empty etag
              )
              .from(gs)
          )
          .execute()

        Response.ok().build()
      } catch {
        case e: Exception =>
          // rollback will remove session + parts rows; we still must abort LakeFS
          try {
            LakeFSStorageClient.abortPresignedMultipartUploads(
              repositoryName,
              filePath,
              uploadIdStr,
              physicalAddr
            )
          } catch { case _: Throwable => () } // best-effort abort; never mask the original error
          throw e
      }
    }
  }
+
+  private def finishMultipartUpload(
+      did: Integer,
+      encodedFilePath: String,
+      uid: Int
+  ): Response = {
+
+    val filePath = validateFilePathOrThrow(
+      URLDecoder.decode(encodedFilePath, StandardCharsets.UTF_8.name())
+    )
+
+    withTransaction(context) { ctx =>
+      if (!userHasWriteAccess(ctx, did, uid)) {
+        throw new ForbiddenException(ERR_USER_HAS_NO_ACCESS_TO_DATASET_MESSAGE)
+      }
+
+      val dataset = getDatasetByID(ctx, did)
+
+      // Lock the session so abort/finish don't race each other
+      val session =
+        try {
+          ctx
+            .selectFrom(DATASET_UPLOAD_SESSION)
+            .where(
+              DATASET_UPLOAD_SESSION.UID
+                .eq(uid)
+                .and(DATASET_UPLOAD_SESSION.DID.eq(did))
+                .and(DATASET_UPLOAD_SESSION.FILE_PATH.eq(filePath))
+            )
+            .forUpdate()
+            .noWait()
+            .fetchOne()
+        } catch {
+          case e: DataAccessException
+              if Option(e.getCause)
+                .collect { case s: SQLException => s.getSQLState }
+                .contains("55P03") =>
+            throw new WebApplicationException(
+              "Upload is already being finalized/aborted",
+              Response.Status.CONFLICT
+            )
+        }
+
+      if (session == null) {
+        throw new NotFoundException("Upload session not found or already 
finalized")
+      }
+
+      val uploadId = session.getUploadId
+      val expectedParts = session.getNumPartsRequested
+
+      val physicalAddr = 
Option(session.getPhysicalAddress).map(_.trim).getOrElse("")
+      if (physicalAddr.isEmpty) {
+        throw new WebApplicationException(
+          "Upload session is missing physicalAddress. Re-init the upload.",
+          Response.Status.INTERNAL_SERVER_ERROR
+        )
+      }
+
+      val total = DSL.count()
+      val done =
+        DSL
+          .count()
+          .filterWhere(DATASET_UPLOAD_SESSION_PART.ETAG.ne(""))
+          .as("done")
+
+      val agg = ctx
+        .select(total.as("total"), done)
+        .from(DATASET_UPLOAD_SESSION_PART)
+        .where(DATASET_UPLOAD_SESSION_PART.UPLOAD_ID.eq(uploadId))
+        .fetchOne()
+
+      val totalCnt = agg.get("total", classOf[java.lang.Integer]).intValue()
+      val doneCnt = agg.get("done", classOf[java.lang.Integer]).intValue()
+
+      if (totalCnt != expectedParts.toLong) {

Review Comment:
   Fixed it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to