This is an automated email from the ASF dual-hosted git repository.

shenghang pushed a commit to branch dev
in repository https://gitbox.apache.org/repos/asf/seatunnel.git


The following commit(s) were added to refs/heads/dev by this push:
     new 123eedc0b3 [Fix][Connector-Jdbc]prevent duplicate XA XID in exactly-once writer and rollback prepared tx on begin failure (#10459)
123eedc0b3 is described below

commit 123eedc0b3ba77a976e1c718cae5ccc1f5dece1d
Author: yzeng1618 <[email protected]>
AuthorDate: Tue Mar 17 21:52:02 2026 +0800

    [Fix][Connector-Jdbc]prevent duplicate XA XID in exactly-once writer and rollback prepared tx on begin failure (#10459)
    
    Co-authored-by: zengyi <[email protected]>
---
 .../multitablesink/MultiTableSinkCommitter.java    |  16 +-
 .../MultiTableSinkCommitterTest.java               |  85 ++++++
 .../jdbc/sink/JdbcExactlyOnceSinkWriter.java       | 155 ++++++++--
 .../jdbc/sink/JdbcExactlyOnceSinkWriterTest.java   | 332 +++++++++++++++++++++
 4 files changed, 565 insertions(+), 23 deletions(-)

diff --git 
a/seatunnel-api/src/main/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitter.java
 
b/seatunnel-api/src/main/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitter.java
index 113e269fd0..f8c5152a4e 100644
--- 
a/seatunnel-api/src/main/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitter.java
+++ 
b/seatunnel-api/src/main/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitter.java
@@ -23,7 +23,6 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.stream.Collectors;
 
 public class MultiTableSinkCommitter implements 
SinkCommitter<MultiTableCommitInfo> {
@@ -67,12 +66,17 @@ public class MultiTableSinkCommitter implements 
SinkCommitter<MultiTableCommitIn
             if (sinkCommitter != null) {
                 List commitInfo =
                         commitInfos.stream()
-                                .map(
+                                .flatMap(
                                         multiTableCommitInfo ->
-                                                multiTableCommitInfo
-                                                        .getCommitInfo()
-                                                        .get(sinkIdentifier))
-                                .filter(Objects::nonNull)
+                                                
multiTableCommitInfo.getCommitInfo().entrySet()
+                                                        .stream()
+                                                        .filter(
+                                                                entry ->
+                                                                        
entry.getKey()
+                                                                               
 .getTableIdentifier()
+                                                                               
 .equals(
+                                                                               
         sinkIdentifier)))
+                                .map(Map.Entry::getValue)
                                 .collect(Collectors.toList());
                 sinkCommitter.abort(commitInfo);
             }
diff --git 
a/seatunnel-api/src/test/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitterTest.java
 
b/seatunnel-api/src/test/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitterTest.java
new file mode 100644
index 0000000000..cab521d304
--- /dev/null
+++ 
b/seatunnel-api/src/test/java/org/apache/seatunnel/api/sink/multitablesink/MultiTableSinkCommitterTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.api.sink.multitablesink;
+
+import org.apache.seatunnel.api.sink.SinkCommitter;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Checks that {@link MultiTableSinkCommitter} routes every commit info to the committer registered
+ * for its table identifier, for both commit and abort, regardless of the subtask index carried in
+ * the {@link SinkIdentifier} key.
+ */
+class MultiTableSinkCommitterTest {
+
+    @Test
+    void testRouteByTableIdentifierForCommitAndAbort() throws IOException {
+        String table1 = "catalog.db.table1";
+        String table2 = "catalog.db.table2";
+
+        RecordingSinkCommitter table1Committer = new RecordingSinkCommitter();
+        RecordingSinkCommitter table2Committer = new RecordingSinkCommitter();
+
+        Map<String, SinkCommitter<?>> sinkCommitters = new HashMap<>();
+        sinkCommitters.put(table1, table1Committer);
+        sinkCommitters.put(table2, table2Committer);
+
+        MultiTableSinkCommitter multiTableSinkCommitter =
+                new MultiTableSinkCommitter(sinkCommitters);
+
+        // Same tables, different subtask indexes: routing must ignore the index.
+        MultiTableCommitInfo commitInfo1 = new MultiTableCommitInfo(new ConcurrentHashMap<>());
+        commitInfo1.getCommitInfo().put(SinkIdentifier.of(table1, 0), "t1-c0");
+        commitInfo1.getCommitInfo().put(SinkIdentifier.of(table2, 0), "t2-c0");
+
+        MultiTableCommitInfo commitInfo2 = new MultiTableCommitInfo(new ConcurrentHashMap<>());
+        commitInfo2.getCommitInfo().put(SinkIdentifier.of(table1, 1), "t1-c1");
+        commitInfo2.getCommitInfo().put(SinkIdentifier.of(table2, 1), "t2-c1");
+
+        List<MultiTableCommitInfo> allCommitInfos = Arrays.asList(commitInfo1, commitInfo2);
+
+        multiTableSinkCommitter.commit(allCommitInfos);
+        Assertions.assertIterableEquals(Arrays.asList("t1-c0", "t1-c1"), table1Committer.committed);
+        Assertions.assertIterableEquals(Arrays.asList("t2-c0", "t2-c1"), table2Committer.committed);
+
+        multiTableSinkCommitter.abort(allCommitInfos);
+        Assertions.assertIterableEquals(Arrays.asList("t1-c0", "t1-c1"), table1Committer.aborted);
+        Assertions.assertIterableEquals(Arrays.asList("t2-c0", "t2-c1"), table2Committer.aborted);
+    }
+
+    /** Stub committer that remembers the last list passed to {@code commit} and {@code abort}. */
+    private static class RecordingSinkCommitter implements SinkCommitter<Object> {
+
+        private List<Object> committed = Collections.emptyList();
+        private List<Object> aborted = Collections.emptyList();
+
+        @Override
+        public List<Object> commit(List<Object> commitInfos) {
+            this.committed = commitInfos;
+            return Collections.emptyList();
+        }
+
+        @Override
+        public void abort(List<Object> commitInfos) {
+            this.aborted = commitInfos;
+        }
+    }
+}
diff --git 
a/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriter.java
 
b/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriter.java
index 2c617386c7..1c9be1f66f 100644
--- 
a/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriter.java
+++ 
b/seatunnel-connectors-v2/connector-jdbc/src/main/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriter.java
@@ -29,8 +29,10 @@ import 
org.apache.seatunnel.common.exception.CommonErrorCodeDeprecated;
 import org.apache.seatunnel.connectors.seatunnel.jdbc.config.JdbcSinkConfig;
 import 
org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorErrorCode;
 import 
org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorException;
+import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.JdbcOutputFormat;
 import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.JdbcOutputFormatBuilder;
 import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.dialect.JdbcDialect;
+import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.executor.JdbcBatchStatementExecutor;
 import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.xa.XaFacade;
 import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.xa.XaGroupOps;
 import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.xa.XaGroupOpsImpl;
@@ -66,6 +68,7 @@ public class JdbcExactlyOnceSinkWriter extends 
AbstractJdbcSinkWriter<Void> {
 
     private final XidGenerator xidGenerator;
 
+    private transient long lastGeneratedTxId = Long.MIN_VALUE;
     private transient Xid currentXid;
     private transient Xid prepareXid;
 
@@ -101,6 +104,24 @@ public class JdbcExactlyOnceSinkWriter extends 
AbstractJdbcSinkWriter<Void> {
         this.xaGroupOps = new XaGroupOpsImpl(xaFacade);
     }
 
+    /**
+     * Package-private constructor that takes already-built collaborators rather than building
+     * them itself, so tests can inject mocks. The given XA facade also serves as the connection
+     * provider.
+     */
+    JdbcExactlyOnceSinkWriter(
+            SinkWriter.Context sinkcontext,
+            JobContext context,
+            List<JdbcSinkState> states,
+            XaFacade xaFacade,
+            XaGroupOps xaGroupOps,
+            XidGenerator xidGenerator,
+            JdbcOutputFormat<SeaTunnelRow, JdbcBatchStatementExecutor<SeaTunnelRow>> outputFormat) {
+        this.sinkcontext = sinkcontext;
+        this.context = context;
+        this.recoverStates = states;
+        this.connectionProvider = xaFacade;
+        this.xaFacade = xaFacade;
+        this.xaGroupOps = xaGroupOps;
+        this.xidGenerator = xidGenerator;
+        this.outputFormat = outputFormat;
+    }
+
     private void tryOpen() {
         if (!isOpen) {
             isOpen = true;
@@ -109,11 +130,11 @@ public class JdbcExactlyOnceSinkWriter extends 
AbstractJdbcSinkWriter<Void> {
                 xaFacade.open();
                 outputFormat.open();
                 if (!recoverStates.isEmpty()) {
-                    Xid xid = recoverStates.get(0).getXid();
-                    // Rollback pending transactions that should not include 
recoverStates
-                    xaGroupOps.recoverAndRollback(context, sinkcontext, 
xidGenerator, xid);
+                    Xid excludeXid = recoverStates.get(0).getXid();
+                    // Rollback pending transactions that should not include 
recoverStates.
+                    xaGroupOps.recoverAndRollback(context, sinkcontext, 
xidGenerator, excludeXid);
                 }
-                beginTx();
+                beginTx(System.currentTimeMillis());
             } catch (Exception e) {
                 throw new JdbcConnectorException(
                         CommonErrorCodeDeprecated.WRITER_OPERATION_FAILED,
@@ -143,6 +164,11 @@ public class JdbcExactlyOnceSinkWriter extends 
AbstractJdbcSinkWriter<Void> {
 
     @Override
     public Optional<XidInfo> prepareCommit() throws IOException {
+        return prepareCommit(System.currentTimeMillis());
+    }
+
+    @Override
+    public Optional<XidInfo> prepareCommit(long checkpointId) throws 
IOException {
         tryOpen();
 
         boolean emptyXaTransaction = false;
@@ -157,24 +183,29 @@ public class JdbcExactlyOnceSinkWriter extends 
AbstractJdbcSinkWriter<Void> {
             }
         }
         this.currentXid = null;
-        beginTx();
+        try {
+            beginTx(checkpointId);
+        } catch (Exception e) {
+            if (!emptyXaTransaction) {
+                rollbackPrepareXidOrThrow(e);
+            } else {
+                prepareXid = null;
+            }
+            throw e;
+        }
         checkState(prepareXid != null, "prepare xid must not be null");
         return emptyXaTransaction ? Optional.empty() : Optional.of(new 
XidInfo(prepareXid, 0));
     }
 
     @Override
-    public void abortPrepare() {}
+    public void abortPrepare() {
+        // Best-effort cleanup of the prepared XID and then the in-flight XID. Both helpers log
+        // rather than throw, skip work when the XA facade is not open, and null out their field.
+        rollbackPrepareXidQuietly();
+        failAndRollbackCurrentXidQuietly();
+    }
 
     @Override
     public void close() throws IOException {
-        if (currentXid != null && xaFacade.isOpen()) {
-            try {
-                LOG.debug("remove current transaction before closing, xid={}", 
currentXid);
-                xaFacade.failAndRollback(currentXid);
-            } catch (Exception e) {
-                LOG.warn("unable to fail/rollback current transaction, 
xid={}", currentXid, e);
-            }
-        }
+        failAndRollbackCurrentXidQuietly();
         try {
             xaFacade.close();
         } catch (Exception e) {
@@ -190,19 +221,109 @@ public class JdbcExactlyOnceSinkWriter extends 
AbstractJdbcSinkWriter<Void> {
         }
     }
 
-    private void beginTx() throws IOException {
+    /**
+     * Generates a new XID from the next unused tx id and starts it on the XA facade.
+     *
+     * @param txIdHint preferred tx id; replaced by the next unused id when it is not larger than
+     *     the last one handed out
+     * @throws JdbcConnectorException if starting the XA transaction fails
+     */
+    private void beginTx(long txIdHint) throws IOException {
         checkState(currentXid == null, "currentXid not null");
-        currentXid = xidGenerator.generateXid(context, sinkcontext, System.currentTimeMillis());
+        long txId = nextTxId(txIdHint);
+        currentXid = xidGenerator.generateXid(context, sinkcontext, txId);
         try {
             xaFacade.start(currentXid);
         } catch (Exception e) {
+            // Clear before throwing so later cleanup (close/abortPrepare) does not try to
+            // fail/rollback an XID that was never started.
+            Xid xid = currentXid;
+            currentXid = null;
             throw new JdbcConnectorException(
                     JdbcConnectorErrorCode.XA_OPERATION_FAILED,
-                    "unable to start xa transaction",
+                    String.format("unable to start xa transaction, xid=%s", xid),
                     e);
         }
     }
 
+    /**
+     * Returns {@code txIdHint} when it is larger than every tx id handed out so far, otherwise the
+     * previous tx id plus one. Tx ids therefore strictly increase for this writer even when a hint
+     * repeats (e.g. the same checkpoint id twice), so the XID generator never sees a tx id twice.
+     * Fails via {@code checkState} if the previous tx id is already {@code Long.MAX_VALUE}.
+     */
+    private long nextTxId(long txIdHint) {
+        long candidate = txIdHint;
+        if (candidate <= lastGeneratedTxId) {
+            checkState(lastGeneratedTxId != Long.MAX_VALUE, "tx id exhausted");
+            candidate = lastGeneratedTxId + 1;
+        }
+        lastGeneratedTxId = candidate;
+        return candidate;
+    }
+
+    /**
+     * Rolls back the last prepared XID, logging instead of throwing on failure. Does nothing when
+     * there is no prepared XID or the XA facade is not open; otherwise always clears {@code
+     * prepareXid}.
+     */
+    private void rollbackPrepareXidQuietly() {
+        if (prepareXid == null || !xaFacade.isOpen()) {
+            return;
+        }
+        Xid xid = prepareXid;
+        try {
+            LOG.debug("rollback prepared transaction, xid={}", xid);
+            xaFacade.rollback(xid);
+        } catch (Exception e) {
+            LOG.warn("unable to rollback prepared transaction, xid={}", xid, e);
+        } finally {
+            prepareXid = null;
+        }
+    }
+
+    /**
+     * Rolls back the prepared XID after beginning the next transaction failed, so that prepared
+     * transaction is not left pending. On success {@code prepareXid} is cleared and the caller
+     * rethrows the original begin failure.
+     *
+     * @param beginTxException the failure raised while beginning the next transaction
+     * @throws JdbcConnectorException if the XA facade is closed (caused by the begin failure), or
+     *     if the rollback itself fails; in that case a subtask-wide recover-and-rollback is
+     *     attempted first and the begin failure (plus any recovery failure) is attached as
+     *     suppressed
+     */
+    private void rollbackPrepareXidOrThrow(Exception beginTxException) {
+        if (prepareXid == null) {
+            return;
+        }
+        Xid xid = prepareXid;
+        if (!xaFacade.isOpen()) {
+            throw new JdbcConnectorException(
+                    JdbcConnectorErrorCode.XA_OPERATION_FAILED,
+                    String.format(
+                            "unable to rollback prepared transaction because xaFacade is closed, xid=%s",
+                            xid),
+                    beginTxException);
+        }
+        try {
+            LOG.warn("begin next transaction failed, rollback prepared transaction, xid={}", xid);
+            xaFacade.rollback(xid);
+            prepareXid = null;
+        } catch (Exception rollbackException) {
+            JdbcConnectorException rollbackFailure =
+                    new JdbcConnectorException(
+                            JdbcConnectorErrorCode.XA_OPERATION_FAILED,
+                            String.format(
+                                    "failed to rollback prepared transaction after begin next transaction failure, xid=%s",
+                                    xid),
+                            rollbackException);
+            rollbackFailure.addSuppressed(beginTxException);
+            tryRecoverPreparedTransactionsAfterRollbackFailure(xid, rollbackFailure);
+            throw rollbackFailure;
+        }
+    }
+
+    /**
+     * Last-resort cleanup after rolling back a prepared XID failed: asks {@code xaGroupOps} to
+     * recover and roll back pending transactions for the current subtask, with no excluded XID. A
+     * recovery failure is logged and attached to {@code rollbackFailure} as suppressed rather than
+     * thrown.
+     *
+     * <p>NOTE(review): because nothing is excluded, this may also roll back XIDs prepared for
+     * earlier checkpoints that a committer has not committed yet — confirm that is acceptable.
+     */
+    private void tryRecoverPreparedTransactionsAfterRollbackFailure(
+            Xid failedRollbackXid, JdbcConnectorException rollbackFailure) {
+        try {
+            LOG.warn(
+                    "rollback prepared transaction failed, try to recover pending transactions for current subtask, xid={}",
+                    failedRollbackXid);
+            xaGroupOps.recoverAndRollback(context, sinkcontext, xidGenerator, null);
+        } catch (Exception recoveryException) {
+            LOG.warn(
+                    "recovery after rollback prepared transaction failure also failed, xid={}",
+                    failedRollbackXid,
+                    recoveryException);
+            rollbackFailure.addSuppressed(recoveryException);
+        }
+    }
+
+    /**
+     * Fails and rolls back the in-flight XID, logging instead of throwing on failure. Does nothing
+     * when there is no current XID or the XA facade is not open; otherwise always clears {@code
+     * currentXid}.
+     */
+    private void failAndRollbackCurrentXidQuietly() {
+        if (currentXid == null || !xaFacade.isOpen()) {
+            return;
+        }
+        Xid xid = currentXid;
+        try {
+            LOG.debug("remove current transaction, xid={}", xid);
+            xaFacade.failAndRollback(xid);
+        } catch (Exception e) {
+            LOG.warn("unable to fail/rollback current transaction, xid={}", xid, e);
+        } finally {
+            currentXid = null;
+        }
+    }
+
     private void prepareCurrentTx() throws IOException {
         checkState(currentXid != null, "no current xid");
         outputFormat.flush();
diff --git 
a/seatunnel-connectors-v2/connector-jdbc/src/test/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriterTest.java
 
b/seatunnel-connectors-v2/connector-jdbc/src/test/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriterTest.java
new file mode 100644
index 0000000000..2fed44178c
--- /dev/null
+++ 
b/seatunnel-connectors-v2/connector-jdbc/src/test/java/org/apache/seatunnel/connectors/seatunnel/jdbc/sink/JdbcExactlyOnceSinkWriterTest.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.seatunnel.connectors.seatunnel.jdbc.sink;
+
+import org.apache.seatunnel.api.common.JobContext;
+import org.apache.seatunnel.api.sink.DefaultSinkWriterContext;
+import org.apache.seatunnel.api.sink.SinkWriter;
+import org.apache.seatunnel.api.table.type.SeaTunnelRow;
+import 
org.apache.seatunnel.connectors.seatunnel.jdbc.exception.JdbcConnectorException;
+import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.JdbcOutputFormat;
+import 
org.apache.seatunnel.connectors.seatunnel.jdbc.internal.executor.JdbcBatchStatementExecutor;
+import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.xa.XaFacade;
+import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.xa.XaGroupOps;
+import org.apache.seatunnel.connectors.seatunnel.jdbc.internal.xa.XidGenerator;
+import org.apache.seatunnel.connectors.seatunnel.jdbc.state.JdbcSinkState;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+
+import javax.transaction.xa.Xid;
+
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.List;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link JdbcExactlyOnceSinkWriter}. Every XA collaborator is a Mockito mock injected
+ * through the writer's package-private constructor, so no database is required.
+ */
+class JdbcExactlyOnceSinkWriterTest {
+
+    @Test
+    void testPrepareCommitWithSameCheckpointGeneratesMonotonicTxIds() throws Exception {
+        TestContext context = createWriter();
+
+        context.writer.prepareCommit(100L);
+        context.writer.prepareCommit(100L);
+
+        // One XID from the first-use open, plus one per prepareCommit.
+        ArgumentCaptor<Long> txIdCaptor = ArgumentCaptor.forClass(Long.class);
+        verify(context.xidGenerator, times(3)).generateXid(any(), any(), txIdCaptor.capture());
+        List<Long> txIds = txIdCaptor.getAllValues();
+        Assertions.assertEquals(3, txIds.size());
+        Assertions.assertTrue(txIds.get(1) > txIds.get(0));
+        Assertions.assertTrue(txIds.get(2) > txIds.get(1));
+    }
+
+    @Test
+    void testPrepareCommitRollbackPreparedXidWhenStartNextTxFailed() throws Exception {
+        TestContext context = createWriter();
+
+        doNothing()
+                .doThrow(new RuntimeException("start next tx failed"))
+                .when(context.xaFacade)
+                .start(any());
+
+        Assertions.assertThrows(
+                JdbcConnectorException.class, () -> context.writer.prepareCommit(10L));
+
+        ArgumentCaptor<Xid> startXidCaptor = ArgumentCaptor.forClass(Xid.class);
+        verify(context.xaFacade, times(2)).start(startXidCaptor.capture());
+        Xid preparedXid = startXidCaptor.getAllValues().get(0);
+        verify(context.xaFacade, times(1)).rollback(preparedXid);
+    }
+
+    @Test
+    void testPrepareCommitThrowWhenRollbackPreparedXidFailedAfterBeginNextTxFailed()
+            throws Exception {
+        TestContext context = createWriter();
+
+        doNothing()
+                .doThrow(new RuntimeException("start next tx failed"))
+                .when(context.xaFacade)
+                .start(any());
+        doThrow(new RuntimeException("rollback prepared failed"))
+                .when(context.xaFacade)
+                .rollback(any());
+
+        JdbcConnectorException exception =
+                Assertions.assertThrows(
+                        JdbcConnectorException.class, () -> context.writer.prepareCommit(10L));
+
+        Assertions.assertTrue(exception.getMessage().contains("rollback prepared transaction"));
+        Assertions.assertEquals(1, exception.getSuppressed().length);
+        Assertions.assertTrue(
+                exception
+                        .getSuppressed()[0]
+                        .getMessage()
+                        .contains("unable to start xa transaction"));
+        // Recovery after a failed rollback must not exclude any XID.
+        ArgumentCaptor<Xid> recoverExcludeXidCaptor = ArgumentCaptor.forClass(Xid.class);
+        verify(context.xaGroupOps, times(1))
+                .recoverAndRollback(any(), any(), any(), recoverExcludeXidCaptor.capture());
+        Assertions.assertNull(recoverExcludeXidCaptor.getValue());
+    }
+
+    @Test
+    void testPrepareCommitAttachRecoveryFailureWhenRollbackAndRecoveryBothFailed()
+            throws Exception {
+        TestContext context = createWriter();
+
+        doNothing()
+                .doThrow(new RuntimeException("start next tx failed"))
+                .when(context.xaFacade)
+                .start(any());
+        doThrow(new RuntimeException("rollback prepared failed"))
+                .when(context.xaFacade)
+                .rollback(any());
+        doThrow(new RuntimeException("recover failed"))
+                .when(context.xaGroupOps)
+                .recoverAndRollback(any(), any(), any(), any());
+
+        JdbcConnectorException exception =
+                Assertions.assertThrows(
+                        JdbcConnectorException.class, () -> context.writer.prepareCommit(10L));
+
+        Assertions.assertTrue(exception.getMessage().contains("rollback prepared transaction"));
+        Assertions.assertEquals(2, exception.getSuppressed().length);
+        Assertions.assertTrue(
+                exception
+                        .getSuppressed()[0]
+                        .getMessage()
+                        .contains("unable to start xa transaction"));
+        Assertions.assertTrue(exception.getSuppressed()[1].getMessage().contains("recover failed"));
+    }
+
+    @Test
+    void testPrepareCommitWithEmptyTransactionDontRollbackPreparedXidWhenStartNextTxFailed()
+            throws Exception {
+        TestContext context = createWriter();
+
+        doThrow(mock(XaFacade.EmptyXaTransactionException.class))
+                .when(context.xaFacade)
+                .endAndPrepare(any());
+        doNothing()
+                .doThrow(new RuntimeException("start next tx failed"))
+                .when(context.xaFacade)
+                .start(any());
+
+        Assertions.assertThrows(
+                JdbcConnectorException.class, () -> context.writer.prepareCommit(10L));
+
+        verify(context.xaFacade, never()).rollback(any());
+        Assertions.assertNull(getPrivateField(context.writer, "prepareXid"));
+    }
+
+    @Test
+    void testInjectedConstructorOpenXidGeneratorOnFirstUse() throws Exception {
+        TestContext context = createWriter();
+
+        verify(context.xidGenerator, never()).open();
+
+        context.writer.prepareCommit(10L);
+
+        verify(context.xidGenerator, times(1)).open();
+    }
+
+    @Test
+    void testTryOpenSkipRecoverAndRollbackWhenRecoverStateIsEmpty() throws Exception {
+        TestContext context = createWriter();
+
+        context.writer.prepareCommit(10L);
+
+        verify(context.xaGroupOps, never()).recoverAndRollback(any(), any(), any(), any());
+    }
+
+    @Test
+    void testTryOpenRecoverAndRollbackWhenRecoverStatePresent() throws Exception {
+        Xid recoveredStateXid = new TestXid(10L);
+        TestContext context =
+                createWriter(Collections.singletonList(new JdbcSinkState(recoveredStateXid)));
+
+        context.writer.prepareCommit(10L);
+
+        ArgumentCaptor<Xid> excludeXidCaptor = ArgumentCaptor.forClass(Xid.class);
+        verify(context.xaGroupOps, times(1))
+                .recoverAndRollback(any(), any(), any(), excludeXidCaptor.capture());
+        Assertions.assertSame(recoveredStateXid, excludeXidCaptor.getValue());
+    }
+
+    @Test
+    void testAbortPrepareRollbackPreparedAndCurrentTransaction() throws Exception {
+        TestContext context = createWriter();
+
+        Xid preparedXid = new TestXid(1L);
+        Xid currentXid = new TestXid(2L);
+        setPrivateField(context.writer, "prepareXid", preparedXid);
+        setPrivateField(context.writer, "currentXid", currentXid);
+
+        context.writer.abortPrepare();
+        verify(context.xaFacade, times(1)).rollback(preparedXid);
+        verify(context.xaFacade, times(1)).failAndRollback(currentXid);
+        Assertions.assertNull(getPrivateField(context.writer, "prepareXid"));
+        Assertions.assertNull(getPrivateField(context.writer, "currentXid"));
+
+        // A second abort has nothing left to roll back.
+        clearInvocations(context.xaFacade);
+        context.writer.abortPrepare();
+        verify(context.xaFacade, never()).rollback(any());
+        verify(context.xaFacade, never()).failAndRollback(any());
+    }
+
+    @Test
+    void testCloseRollbackCurrentTransactionOnly() throws Exception {
+        TestContext context = createWriter();
+
+        Xid preparedXid = new TestXid(3L);
+        Xid currentXid = new TestXid(4L);
+        setPrivateField(context.writer, "prepareXid", preparedXid);
+        setPrivateField(context.writer, "currentXid", currentXid);
+
+        context.writer.close();
+
+        verify(context.xaFacade, never()).rollback(any());
+        verify(context.xaFacade, times(1)).failAndRollback(currentXid);
+        verify(context.xaFacade, times(1)).close();
+        verify(context.outputFormat, times(1)).close();
+        verify(context.xidGenerator, times(1)).close();
+        Assertions.assertNull(getPrivateField(context.writer, "prepareXid"));
+        Assertions.assertNull(getPrivateField(context.writer, "currentXid"));
+    }
+
+    private TestContext createWriter() throws Exception {
+        return createWriter(Collections.<JdbcSinkState>emptyList());
+    }
+
+    /**
+     * Builds a writer over mocks: the XA facade always reports itself open, and the XID generator
+     * returns a {@link TestXid} wrapping the requested tx id.
+     */
+    private TestContext createWriter(List<JdbcSinkState> states) throws Exception {
+        SinkWriter.Context sinkWriterContext = new DefaultSinkWriterContext(0, 1);
+        JobContext jobContext = new JobContext(1L);
+        XaFacade xaFacade = mock(XaFacade.class);
+        XaGroupOps xaGroupOps = mock(XaGroupOps.class);
+        XidGenerator xidGenerator = mock(XidGenerator.class);
+        JdbcOutputFormat<SeaTunnelRow, JdbcBatchStatementExecutor<SeaTunnelRow>> outputFormat =
+                mock(JdbcOutputFormat.class);
+
+        when(xaFacade.isOpen()).thenReturn(true);
+        when(xidGenerator.generateXid(any(), any(), anyLong()))
+                .thenAnswer(invocation -> new TestXid((Long) invocation.getArguments()[2]));
+
+        JdbcExactlyOnceSinkWriter writer =
+                new JdbcExactlyOnceSinkWriter(
+                        sinkWriterContext,
+                        jobContext,
+                        states,
+                        xaFacade,
+                        xaGroupOps,
+                        xidGenerator,
+                        outputFormat);
+        return new TestContext(writer, xaFacade, xaGroupOps, xidGenerator, outputFormat);
+    }
+
+    private static void setPrivateField(Object target, String fieldName, Object value)
+            throws Exception {
+        Field field = JdbcExactlyOnceSinkWriter.class.getDeclaredField(fieldName);
+        field.setAccessible(true);
+        field.set(target, value);
+    }
+
+    private static Object getPrivateField(Object target, String fieldName) throws Exception {
+        Field field = JdbcExactlyOnceSinkWriter.class.getDeclaredField(fieldName);
+        field.setAccessible(true);
+        return field.get(target);
+    }
+
+    /** Holder for the writer under test and the mocks it was built with. */
+    private static class TestContext {
+        private final JdbcExactlyOnceSinkWriter writer;
+        private final XaFacade xaFacade;
+        private final XaGroupOps xaGroupOps;
+        private final XidGenerator xidGenerator;
+        private final JdbcOutputFormat<SeaTunnelRow, JdbcBatchStatementExecutor<SeaTunnelRow>>
+                outputFormat;
+
+        private TestContext(
+                JdbcExactlyOnceSinkWriter writer,
+                XaFacade xaFacade,
+                XaGroupOps xaGroupOps,
+                XidGenerator xidGenerator,
+                JdbcOutputFormat<SeaTunnelRow, JdbcBatchStatementExecutor<SeaTunnelRow>>
+                        outputFormat) {
+            this.writer = writer;
+            this.xaFacade = xaFacade;
+            this.xaGroupOps = xaGroupOps;
+            this.xidGenerator = xidGenerator;
+            this.outputFormat = outputFormat;
+        }
+    }
+
+    /**
+     * Minimal {@link Xid} whose global transaction id holds the low 32 bits of {@code txId}. It
+     * does not override equals, so mock verifications match by instance identity.
+     */
+    private static class TestXid implements Xid {
+        private final long txId;
+
+        private TestXid(long txId) {
+            this.txId = txId;
+        }
+
+        @Override
+        public int getFormatId() {
+            return 201;
+        }
+
+        @Override
+        public byte[] getGlobalTransactionId() {
+            return new byte[] {
+                (byte) txId, (byte) (txId >>> 8), (byte) (txId >>> 16), (byte) (txId >>> 24)
+            };
+        }
+
+        @Override
+        public byte[] getBranchQualifier() {
+            return new byte[] {0, 0, 0, 1};
+        }
+    }
+}

Reply via email to