This is an automated email from the ASF dual-hosted git repository.

jiabaosun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/flink-connector-mongodb.git


The following commit(s) were added to refs/heads/main by this push:
     new a2f1083  [FLINK-35616][Connectors/MongoDB] Support upsert into sharded collections (#37)
a2f1083 is described below

commit a2f1083c2b0020cde626681e4ebcd0bec649547c
Author: Jiabao Sun <[email protected]>
AuthorDate: Tue Jul 16 09:57:01 2024 +0800

    [FLINK-35616][Connectors/MongoDB] Support upsert into sharded collections (#37)
---
 docs/content.zh/docs/connectors/table/mongodb.md   |  58 +++++
 docs/content/docs/connectors/table/mongodb.md      |  62 +++++
 .../mongodb/common/utils/MongoValidationUtils.java |   4 +-
 .../mongodb/table/MongoDynamicTableFactory.java    |  15 +-
 .../mongodb/table/MongoDynamicTableSink.java       |  64 +++--
 ...xtractor.java => MongoPrimaryKeyExtractor.java} |  16 +-
 .../mongodb/table/MongoShardKeysExtractor.java     | 121 +++++++++
 .../MongoRowDataSerializationSchema.java           |  23 +-
 .../mongodb/source/MongoSourceITCase.java          |  60 ++---
 .../table/MongoDynamicTableFactoryTest.java        |  19 +-
 .../table/MongoPartitionedTableSinkITCase.java     | 290 +++++++++++++++++++++
 ...Test.java => MongoPrimaryKeyExtractorTest.java} |  23 +-
 .../mongodb/table/MongoShardKeysExtractorTest.java | 117 +++++++++
 .../connector/mongodb/testutils/MongoTestUtil.java |  44 ++++
 14 files changed, 805 insertions(+), 111 deletions(-)

diff --git a/docs/content.zh/docs/connectors/table/mongodb.md b/docs/content.zh/docs/connectors/table/mongodb.md
index c111481..dde8363 100644
--- a/docs/content.zh/docs/connectors/table/mongodb.md
+++ b/docs/content.zh/docs/connectors/table/mongodb.md
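@@ -340,6 +340,64 @@ The main purpose of the lookup cache is to improve the temporal table join of the MongoDB connector
 If there are failures, the Flink job will recover from the last successful checkpoint and re-process, which can lead to re-processing messages during recovery.
 The upsert mode is highly recommended, as it helps avoid violating database primary key constraints or producing duplicate data if records need to be re-processed.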
 
+### [Upsert on a sharded collection](https://www.mongodb.com/docs/manual/reference/method/db.collection.updateOne/#upsert-on-a-sharded-collection)
+
+As the MongoDB reference states:
+> To use db.collection.updateOne() on a sharded collection:
+>
+> - If you don't specify upsert: true, you must include an exact match on the _id field or target a single shard (such as by including the shard key in the filter).
+> - If you specify upsert: true, the filter must include the shard key.
+>
+> However, documents in a sharded collection can be missing the shard key fields.
+> To target a document that is missing the shard key, you can use the null equality match
+> in conjunction with another filter condition (such as on the _id field).
+
+When writing to a sharded collection in upsert mode, the shard key values must be added to the filter, for example:
+```javascript
+db.collection.updateOne(
+    {
+        _id: ObjectId('<value>'),
+        shardKey0: '<value>',
+        shardKey1: '<value>'
+    },
+    { $set: { status: "D" }},
+    { upsert: true }
+);
+```
+
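+When creating a Flink SQL sink table that maps to a sharded collection, declare the shard keys with the `PARTITIONED BY` syntax.
+The shard key values are extracted from each individual record at runtime and added to the filter.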
+
+```sql
+CREATE TABLE MySinkTable (
+    _id       BIGINT,
+    shardKey0 STRING,
+    shardKey1 STRING,
+    status    STRING,
+    PRIMARY KEY (_id) NOT ENFORCED
+) PARTITIONED BY (shardKey0, shardKey1) WITH (
+    'connector' = 'mongodb',
+    'uri' = 'mongodb://user:password@127.0.0.1:27017',
+    'database' = 'my_db',
+    'collection' = 'users'
+);
+
+-- Write to the sharded collection with shard key values taken from each record
+INSERT INTO MySinkTable SELECT _id, shardKey0, shardKey1, status FROM T;
+
+-- Specify fixed shard key values
+INSERT INTO MySinkTable PARTITION(shardKey0 = 'value0', shardKey1 = 'value1') SELECT 1, 'INIT';
+
+-- Specify a fixed shard key value (shardKey0) and a dynamic shard key value (shardKey1)
+INSERT INTO MySinkTable PARTITION(shardKey0 = 'value0') SELECT 1, 'value1', 'INIT';
+```
+{{< hint warning >}}
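+Limitation: although shard key values are no longer immutable in MongoDB 4.2 and later,
+upserting into a sharded collection with the MongoDB connector requires the shard key values to remain immutable.
+In upsert mode only the updated shard key values are available, and the original values cannot be added to the filter,
+which may cause duplicate record errors.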
+{{< /hint >}}
+
 ### Filters Pushdown
 
 MongoDB supports pushing down simple comparisons and logical filters of Flink SQL to optimize queries.
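The limitation above can be made concrete with a minimal Java sketch. It models a collection as an in-memory list with a unique `_id`, and an upsert as "update the first document matching every filter field, otherwise insert the filter fields plus the `$set` fields". `ToyCollection` and `ShardKeyChangeDemo` are hypothetical helpers, not connector or MongoDB driver classes, and the model ignores routing across shards.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/** Hypothetical in-memory collection with a unique _id; not the MongoDB driver API. */
class ToyCollection {
    private final List<Map<String, Object>> docs = new ArrayList<>();

    /** Models updateOne(filter, {$set: fields}, {upsert: true}) with equality-only filters. */
    void upsert(Map<String, Object> filter, Map<String, Object> fields) {
        for (Map<String, Object> doc : docs) {
            if (matches(doc, filter)) {
                doc.putAll(fields); // $set on the matched document
                return;
            }
        }
        // No match: the inserted document combines the filter's equality fields and $set fields.
        Map<String, Object> inserted = new LinkedHashMap<>(filter);
        inserted.putAll(fields);
        for (Map<String, Object> doc : docs) {
            if (Objects.equals(doc.get("_id"), inserted.get("_id"))) {
                throw new IllegalStateException("duplicate key: _id=" + inserted.get("_id"));
            }
        }
        docs.add(inserted);
    }

    int size() {
        return docs.size();
    }

    private static boolean matches(Map<String, Object> doc, Map<String, Object> filter) {
        for (Map.Entry<String, Object> e : filter.entrySet()) {
            if (!Objects.equals(doc.get(e.getKey()), e.getValue())) {
                return false;
            }
        }
        return true;
    }
}

class ShardKeyChangeDemo {
    /** Filter shaped like the sink's: _id plus the current shard key value. */
    static Map<String, Object> filter(long id, String shardKey0) {
        Map<String, Object> f = new LinkedHashMap<>();
        f.put("_id", id);
        f.put("shardKey0", shardKey0);
        return f;
    }

    /** Returns true when an upsert carrying a changed shard key value hits a duplicate _id. */
    static boolean changedShardKeyCollides() {
        ToyCollection coll = new ToyCollection();
        coll.upsert(filter(1L, "a"), Map.of("status", "INIT"));
        coll.upsert(filter(1L, "a"), Map.of("status", "D")); // same shard key: updated in place
        try {
            // Only the new value "b" is known, so the filter no longer matches the stored document.
            coll.upsert(filter(1L, "b"), Map.of("status", "D"));
            return false;
        } catch (IllegalStateException duplicate) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(changedShardKeyCollides()); // prints true
    }
}
```

In a real cluster the new document would be routed to the shard owning the new value; since the `_id` index only enforces uniqueness per shard when `_id` is not (a prefix of) the shard key, the outcome may be a duplicate-key error or a second document with the same `_id`.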
diff --git a/docs/content/docs/connectors/table/mongodb.md b/docs/content/docs/connectors/table/mongodb.md
index 7ce3bdf..340a925 100644
--- a/docs/content/docs/connectors/table/mongodb.md
+++ b/docs/content/docs/connectors/table/mongodb.md
@@ -371,6 +371,68 @@ If there are failures, the Flink job will recover and re-process from last succe
 which can lead to re-processing messages during recovery. The upsert mode is highly recommended as
 it helps avoid constraint violations or duplicate data if records need to be re-processed.
 
+### [Upsert on a sharded collection](https://www.mongodb.com/docs/manual/reference/method/db.collection.updateOne/#upsert-on-a-sharded-collection)
+
+As the MongoDB reference states:
+> To use db.collection.updateOne() on a sharded collection:
+>
+> - If you don't specify upsert: true, you must include an exact match on the _id field or target a single shard (such as by including the shard key in the filter).
+> - If you specify upsert: true, the filter must include the shard key.
+>
+> However, documents in a sharded collection can be missing the shard key fields.
+> To target a document that is missing the shard key, you can use the null equality match
+> in conjunction with another filter condition (such as on the _id field).
+
+When upserting into a sharded collection, the shard key values must be added to the filter.
+For example:
+```javascript
+db.collection.updateOne(
+    {
+        _id: ObjectId('<value>'),
+        shardKey0: '<value>',
+        shardKey1: '<value>'
+    },
+    { $set: { status: "D" }},
+    { upsert: true }
+);
+```
+
+In Flink SQL, when creating a sink table, the shard keys need to be declared using the `PARTITIONED BY` syntax.
+The shard key values are obtained from each individual record at runtime and added to the filter.
+
+```sql
+CREATE TABLE MySinkTable (
+    _id       BIGINT,
+    shardKey0 STRING,
+    shardKey1 STRING,
+    status    STRING,
+    PRIMARY KEY (_id) NOT ENFORCED
+) PARTITIONED BY (shardKey0, shardKey1) WITH (
+    'connector' = 'mongodb',
+    'uri' = 'mongodb://user:password@127.0.0.1:27017',
+    'database' = 'my_db',
+    'collection' = 'users'
+);
+
+-- Insert with dynamic partition
+INSERT INTO MySinkTable SELECT _id, shardKey0, shardKey1, status FROM T;
+
+-- Insert with static partition
+INSERT INTO MySinkTable PARTITION(shardKey0 = 'value0', shardKey1 = 'value1') SELECT 1, 'INIT';
+
+-- Insert with static (shardKey0) and dynamic (shardKey1) partition
+INSERT INTO MySinkTable PARTITION(shardKey0 = 'value0') SELECT 1, 'value1', 'INIT';
+```
+{{< hint warning >}}
+LIMITATION: Although shard key values are no longer immutable in MongoDB 4.2 and later,
+you must ensure that shard key values remain immutable when upserting into a sharded collection.
+
+In Flink SQL upsert mode, only the updated shard key values are available to the sink,
+so the original shard key values cannot be included in the filter,
+which may cause a duplicate record error.
+{{< /hint >}}
+
 ### Filters Pushdown
 
 MongoDB supports pushing down simple comparisons and logical filters to optimize queries.
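The three `INSERT` forms in the docs above differ only in where each shard key value comes from. As a rough mental model (a simplified sketch, not Flink planner code), static `PARTITION` values fill their own columns and the `SELECT` list fills the remaining columns in table order, which is why the mixed example needs three select expressions (`_id`, `shardKey1`, `status`). `StaticPartitionRowBuilder` is a hypothetical helper.

```java
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical model of lining up a SELECT list with static PARTITION values. */
class StaticPartitionRowBuilder {
    /**
     * Static partition columns take their value from the PARTITION clause; every other
     * column takes the next SELECT value, in table column order.
     */
    static Map<String, Object> buildRow(
            List<String> tableColumns,
            Map<String, Object> staticPartition,
            List<Object> selectValues) {
        int expected = tableColumns.size() - staticPartition.size();
        if (selectValues.size() != expected) {
            throw new IllegalArgumentException(
                    "SELECT returns " + selectValues.size() + " columns, expected " + expected);
        }
        Map<String, Object> row = new LinkedHashMap<>();
        Iterator<Object> next = selectValues.iterator();
        for (String column : tableColumns) {
            row.put(
                    column,
                    staticPartition.containsKey(column) ? staticPartition.get(column) : next.next());
        }
        return row;
    }

    public static void main(String[] args) {
        List<String> columns = List.of("_id", "shardKey0", "shardKey1", "status");
        // PARTITION(shardKey0 = 'value0') SELECT 1, 'value1', 'INIT'
        System.out.println(
                buildRow(columns, Map.of("shardKey0", "value0"), List.of(1L, "value1", "INIT")));
        // prints {_id=1, shardKey0=value0, shardKey1=value1, status=INIT}
    }
}
```

Whichever form is used, the sink only sees the completed row, so the shard keys extractor reads static and dynamic values the same way at runtime.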
diff --git a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/common/utils/MongoValidationUtils.java b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/common/utils/MongoValidationUtils.java
index 40f96fd..809e984 100644
--- a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/common/utils/MongoValidationUtils.java
+++ b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/common/utils/MongoValidationUtils.java
@@ -18,7 +18,7 @@
 package org.apache.flink.connector.mongodb.common.utils;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.connector.mongodb.table.MongoKeyExtractor;
+import org.apache.flink.connector.mongodb.table.MongoPrimaryKeyExtractor;
 import org.apache.flink.connector.mongodb.table.converter.RowDataToBsonConverters;
 import org.apache.flink.table.api.ValidationException;
 import org.apache.flink.table.types.DataType;
@@ -93,7 +93,7 @@ public class MongoValidationUtils {
      *   <li>Starting in version 4.2, MongoDB removes the Index Key Limit.
      * </ul>
      *
-     * <p>As of now it is extracted by {@link MongoKeyExtractor} according to the primary key
+     * <p>As of now it is extracted by {@link MongoPrimaryKeyExtractor} according to the primary key
      * specified by the Flink table schema.
      *
      * <ul>
diff --git a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactory.java b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactory.java
index 1430f23..59b8ba5 100644
--- a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactory.java
+++ b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactory.java
@@ -24,19 +24,14 @@ import org.apache.flink.connector.mongodb.common.config.MongoConnectionOptions;
 import org.apache.flink.connector.mongodb.sink.config.MongoWriteOptions;
 import org.apache.flink.connector.mongodb.source.config.MongoReadOptions;
 import org.apache.flink.connector.mongodb.table.config.MongoConfiguration;
-import org.apache.flink.table.catalog.ResolvedSchema;
 import org.apache.flink.table.connector.sink.DynamicTableSink;
 import org.apache.flink.table.connector.source.DynamicTableSource;
 import org.apache.flink.table.connector.source.lookup.LookupOptions;
 import org.apache.flink.table.connector.source.lookup.cache.DefaultLookupCache;
 import org.apache.flink.table.connector.source.lookup.cache.LookupCache;
-import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.factories.DynamicTableSinkFactory;
 import org.apache.flink.table.factories.DynamicTableSourceFactory;
 import org.apache.flink.table.factories.FactoryUtil;
-import org.apache.flink.util.function.SerializableFunction;
-
-import org.bson.BsonValue;
 
 import javax.annotation.Nullable;
 
@@ -148,18 +143,12 @@ public class MongoDynamicTableFactory
         MongoConfiguration config = new MongoConfiguration(helper.getOptions());
         helper.validate();
 
-        ResolvedSchema schema = context.getCatalogTable().getResolvedSchema();
-        boolean isUpsert = schema.getPrimaryKey().isPresent();
-        SerializableFunction<RowData, BsonValue> keyExtractor =
-                MongoKeyExtractor.createKeyExtractor(schema);
-
         return new MongoDynamicTableSink(
                 getConnectionOptions(config),
                 getWriteOptions(config),
                 config.getSinkParallelism(),
-                isUpsert,
-                context.getPhysicalRowDataType(),
-                keyExtractor);
+                context.getCatalogTable().getResolvedSchema(),
+                context.getCatalogTable().getPartitionKeys().toArray(new String[0]));
     }
 
     @Nullable
diff --git a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableSink.java b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableSink.java
index 335080a..99bdc7f 100644
--- a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableSink.java
+++ b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableSink.java
@@ -24,46 +24,59 @@ import org.apache.flink.connector.mongodb.sink.config.MongoWriteOptions;
 import org.apache.flink.connector.mongodb.table.converter.RowDataToBsonConverters;
 import org.apache.flink.connector.mongodb.table.converter.RowDataToBsonConverters.RowDataToBsonConverter;
 import org.apache.flink.connector.mongodb.table.serialization.MongoRowDataSerializationSchema;
+import org.apache.flink.table.catalog.ResolvedSchema;
 import org.apache.flink.table.connector.ChangelogMode;
 import org.apache.flink.table.connector.sink.DynamicTableSink;
 import org.apache.flink.table.connector.sink.SinkV2Provider;
+import org.apache.flink.table.connector.sink.abilities.SupportsPartitioning;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.util.function.SerializableFunction;
 
+import org.bson.BsonDocument;
 import org.bson.BsonValue;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
+import java.util.Arrays;
+import java.util.Map;
 import java.util.Objects;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /** A {@link DynamicTableSink} for MongoDB. */
 @Internal
-public class MongoDynamicTableSink implements DynamicTableSink {
+public class MongoDynamicTableSink implements DynamicTableSink, SupportsPartitioning {
+
+    private static final Logger LOG = LoggerFactory.getLogger(MongoDynamicTableSink.class);
 
     private final MongoConnectionOptions connectionOptions;
     private final MongoWriteOptions writeOptions;
     @Nullable private final Integer parallelism;
     private final boolean isUpsert;
-    private final DataType physicalRowDataType;
-    private final SerializableFunction<RowData, BsonValue> keyExtractor;
+    private final ResolvedSchema resolvedSchema;
+    private final String[] partitionKeys;
+    private final SerializableFunction<RowData, BsonValue> primaryKeyExtractor;
+    private final SerializableFunction<RowData, BsonDocument> shardKeysExtractor;
 
     public MongoDynamicTableSink(
             MongoConnectionOptions connectionOptions,
             MongoWriteOptions writeOptions,
             @Nullable Integer parallelism,
-            boolean isUpsert,
-            DataType physicalRowDataType,
-            SerializableFunction<RowData, BsonValue> keyExtractor) {
+            ResolvedSchema resolvedSchema,
+            String[] partitionKeys) {
         this.connectionOptions = checkNotNull(connectionOptions);
         this.writeOptions = checkNotNull(writeOptions);
         this.parallelism = parallelism;
-        this.isUpsert = isUpsert;
-        this.physicalRowDataType = checkNotNull(physicalRowDataType);
-        this.keyExtractor = checkNotNull(keyExtractor);
+        this.resolvedSchema = checkNotNull(resolvedSchema);
+        this.partitionKeys = checkNotNull(partitionKeys);
+        this.isUpsert = resolvedSchema.getPrimaryKey().isPresent();
+        this.primaryKeyExtractor =
+                MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(resolvedSchema);
+        this.shardKeysExtractor =
+                MongoShardKeysExtractor.createShardKeysExtractor(resolvedSchema, partitionKeys);
     }
 
     @Override
@@ -79,10 +92,11 @@ public class MongoDynamicTableSink implements DynamicTableSink {
     public SinkRuntimeProvider getSinkRuntimeProvider(Context context) {
         final RowDataToBsonConverter rowDataToBsonConverter =
                 RowDataToBsonConverters.createConverter(
-                        (RowType) physicalRowDataType.getLogicalType());
+                        (RowType) resolvedSchema.toPhysicalRowDataType().getLogicalType());
 
         final MongoRowDataSerializationSchema serializationSchema =
-                new MongoRowDataSerializationSchema(rowDataToBsonConverter, keyExtractor);
+                new MongoRowDataSerializationSchema(
+                        rowDataToBsonConverter, primaryKeyExtractor, shardKeysExtractor);
 
         final MongoSink<RowData> mongoSink =
                 MongoSink.<RowData>builder()
@@ -99,15 +113,16 @@ public class MongoDynamicTableSink implements DynamicTableSink {
         return SinkV2Provider.of(mongoSink, parallelism);
     }
 
+    @Override
+    public void applyStaticPartition(Map<String, String> partition) {
+        // The partition key values are obtained at runtime, so only log the static partition here.
+        LOG.info("Applied static partition: {}", partition);
+    }
+
     @Override
     public MongoDynamicTableSink copy() {
         return new MongoDynamicTableSink(
-                connectionOptions,
-                writeOptions,
-                parallelism,
-                isUpsert,
-                physicalRowDataType,
-                keyExtractor);
+                connectionOptions, writeOptions, parallelism, resolvedSchema, partitionKeys);
     }
 
     @Override
@@ -128,12 +143,19 @@ public class MongoDynamicTableSink implements DynamicTableSink {
                 && Objects.equals(writeOptions, that.writeOptions)
                 && Objects.equals(parallelism, that.parallelism)
                 && Objects.equals(isUpsert, that.isUpsert)
-                && Objects.equals(physicalRowDataType, that.physicalRowDataType);
+                && Objects.equals(resolvedSchema, that.resolvedSchema)
+                && Arrays.equals(partitionKeys, that.partitionKeys);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(
-                connectionOptions, writeOptions, parallelism, isUpsert, physicalRowDataType);
+        return 31
+                        * Objects.hash(
+                                connectionOptions,
+                                writeOptions,
+                                parallelism,
+                                isUpsert,
+                                resolvedSchema)
+                + Arrays.hashCode(partitionKeys);
     }
 }
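The `equals`/`hashCode` change above switches to `Arrays.equals` and `Arrays.hashCode` because `partitionKeys` is a `String[]`: `Objects.equals` on two arrays compares references, and an array passed to `Objects.hash` together with other fields is hashed by identity. A minimal sketch of the same pattern on a hypothetical two-field value class (`PartitionedSinkKey` is illustrative only, not connector code):

```java
import java.util.Arrays;
import java.util.Objects;

/** Hypothetical value class mirroring the sink's equals/hashCode pattern. */
final class PartitionedSinkKey {
    private final String collection;
    private final String[] partitionKeys;

    PartitionedSinkKey(String collection, String[] partitionKeys) {
        this.collection = collection;
        this.partitionKeys = partitionKeys.clone(); // defensive copy
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof PartitionedSinkKey)) {
            return false;
        }
        PartitionedSinkKey that = (PartitionedSinkKey) o;
        // Objects.equals on arrays would compare references, so compare contents instead.
        return Objects.equals(collection, that.collection)
                && Arrays.equals(partitionKeys, that.partitionKeys);
    }

    @Override
    public int hashCode() {
        // Same shape as the sink: 31 * hash(scalar fields) + content hash of the array.
        return 31 * Objects.hash(collection) + Arrays.hashCode(partitionKeys);
    }

    public static void main(String[] args) {
        String[] a = {"shardKey0", "shardKey1"};
        String[] b = {"shardKey0", "shardKey1"};
        System.out.println(Objects.equals(a, b)); // prints false (reference comparison)
        System.out.println(
                new PartitionedSinkKey("users", a).equals(new PartitionedSinkKey("users", b)));
        // prints true
    }
}
```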
diff --git a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoKeyExtractor.java b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoPrimaryKeyExtractor.java
similarity index 91%
rename from flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoKeyExtractor.java
rename to flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoPrimaryKeyExtractor.java
index 6eba160..a82a1d1 100644
--- a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoKeyExtractor.java
+++ b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoPrimaryKeyExtractor.java
@@ -40,16 +40,16 @@ import java.util.Optional;
 import static org.apache.flink.connector.mongodb.common.utils.MongoConstants.ID_FIELD;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
-/** An extractor for a MongoDB key from a {@link RowData}. */
+/** An extractor for a MongoDB primary key from a {@link RowData}. */
 @Internal
-public class MongoKeyExtractor implements SerializableFunction<RowData, BsonValue> {
+public class MongoPrimaryKeyExtractor implements SerializableFunction<RowData, BsonValue> {
 
     private static final long serialVersionUID = 1L;
 
     public static final String RESERVED_ID = ID_FIELD;
 
-    private static final AppendOnlyKeyExtractor APPEND_ONLY_KEY_EXTRACTOR =
-            new AppendOnlyKeyExtractor();
+    private static final AppendOnlyPrimaryKeyExtractor APPEND_ONLY_KEY_EXTRACTOR =
+            new AppendOnlyPrimaryKeyExtractor();
 
     private final int[] primaryKeyIndexes;
 
@@ -57,7 +57,7 @@ public class MongoKeyExtractor implements SerializableFunction<RowData, BsonValu
 
     private final FieldGetter primaryKeyGetter;
 
-    private MongoKeyExtractor(LogicalType primaryKeyType, int[] primaryKeyIndexes) {
+    private MongoPrimaryKeyExtractor(LogicalType primaryKeyType, int[] primaryKeyIndexes) {
         this.primaryKeyIndexes = primaryKeyIndexes;
         this.primaryKeyConverter = RowDataToBsonConverters.createFieldDataConverter(primaryKeyType);
 
@@ -85,7 +85,7 @@ public class MongoKeyExtractor implements SerializableFunction<RowData, BsonValu
         return keyValue;
     }
 
-    public static SerializableFunction<RowData, BsonValue> createKeyExtractor(
+    public static SerializableFunction<RowData, BsonValue> createPrimaryKeyExtractor(
             ResolvedSchema resolvedSchema) {
 
         Optional<UniqueConstraint> primaryKey = resolvedSchema.getPrimaryKey();
@@ -125,7 +125,7 @@ public class MongoKeyExtractor implements SerializableFunction<RowData, BsonValu
 
         MongoValidationUtils.validatePrimaryKey(primaryKeyType);
 
-        return new MongoKeyExtractor(primaryKeyType.getLogicalType(), primaryKeyIndexes);
+        return new MongoPrimaryKeyExtractor(primaryKeyType.getLogicalType(), primaryKeyIndexes);
     }
 
     private static boolean isCompoundPrimaryKey(int[] primaryKeyIndexes) {
@@ -141,7 +141,7 @@ public class MongoKeyExtractor implements SerializableFunction<RowData, BsonValu
      * use static class instead of lambda because the maven shade plugin cannot relocate classes in
      * SerializedLambdas (MSHADE-260).
      */
-    private static class AppendOnlyKeyExtractor
+    private static class AppendOnlyPrimaryKeyExtractor
             implements SerializableFunction<RowData, BsonValue> {
         private static final long serialVersionUID = 1L;
 
diff --git a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoShardKeysExtractor.java b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoShardKeysExtractor.java
new file mode 100644
index 0000000..006c83b
--- /dev/null
+++ b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/MongoShardKeysExtractor.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.connector.mongodb.table;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.connector.mongodb.table.converter.RowDataToBsonConverters;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.connector.Projection;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.ProjectedRowData;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.util.function.SerializableFunction;
+
+import org.bson.BsonDocument;
+import org.bson.BsonObjectId;
+import org.bson.BsonValue;
+import org.bson.types.ObjectId;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+/** An extractor for MongoDB shard keys from a {@link RowData}. */
+@Internal
+public class MongoShardKeysExtractor implements SerializableFunction<RowData, BsonDocument> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final Logger LOG = LoggerFactory.getLogger(MongoShardKeysExtractor.class);
+
+    private static final BsonDocument EMPTY_DOCUMENT = new BsonDocument();
+
+    private final SerializableFunction<Object, BsonValue> shardKeysConverter;
+
+    private final RowData.FieldGetter shardKeysGetter;
+
+    private MongoShardKeysExtractor(LogicalType shardKeysType, int[] shardKeysIndexes) {
+        this.shardKeysConverter = RowDataToBsonConverters.createFieldDataConverter(shardKeysType);
+        this.shardKeysGetter =
+                rowData -> ProjectedRowData.from(shardKeysIndexes).replaceRow(rowData);
+    }
+
+    @Override
+    public BsonDocument apply(RowData rowData) {
+        BsonDocument shardKeysDoc =
+                Optional.ofNullable(shardKeysGetter.getFieldOrNull(rowData))
+                        .map(shardKeys -> shardKeysConverter.apply(shardKeys).asDocument())
+                        .orElse(EMPTY_DOCUMENT);
+
+        shardKeysDoc
+                .entrySet()
+                .forEach(
+                        entry -> {
+                            if (entry.getValue().isString()) {
+                                String keyString = entry.getValue().asString().getValue();
+                                // Try to restore MongoDB's ObjectId from string.
+                                if (ObjectId.isValid(keyString)) {
+                                    entry.setValue(new BsonObjectId(new ObjectId(keyString)));
+                                }
+                            }
+                        });
+
+        return shardKeysDoc;
+    }
+
+    public static SerializableFunction<RowData, BsonDocument> createShardKeysExtractor(
+            ResolvedSchema resolvedSchema, String[] shardKeys) {
+        // no shard keys are declared.
+        if (shardKeys.length == 0) {
+            return new NoOpShardKeysExtractor();
+        }
+
+        int[] shardKeysIndexes = getShardKeysIndexes(resolvedSchema.getColumnNames(), shardKeys);
+        DataType physicalRowDataType = resolvedSchema.toPhysicalRowDataType();
+        DataType shardKeysType = Projection.of(shardKeysIndexes).project(physicalRowDataType);
+
+        MongoShardKeysExtractor shardKeysExtractor =
+                new MongoShardKeysExtractor(shardKeysType.getLogicalType(), shardKeysIndexes);
+
+        LOG.info("Shard keys extractor created, shard keys: {}", Arrays.toString(shardKeys));
+        return shardKeysExtractor;
+    }
+
+    private static int[] getShardKeysIndexes(List<String> columnNames, String[] shardKeys) {
+        return Arrays.stream(shardKeys).mapToInt(columnNames::indexOf).toArray();
+    }
+
+    /**
+     * It behaves as a no-op extractor when no shard keys are declared. We use a static class instead
+     * of a lambda because the maven shade plugin cannot relocate classes in SerializedLambdas
+     * (MSHADE-260).
+     */
+    private static class NoOpShardKeysExtractor
+            implements SerializableFunction<RowData, BsonDocument> {
+
+        private static final long serialVersionUID = 1L;
+
+        @Override
+        public BsonDocument apply(RowData rowData) {
+            return EMPTY_DOCUMENT;
+        }
+    }
+}
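The extractor above projects the `PARTITIONED BY` columns out of each row and turns any string value that is a valid ObjectId hex string back into an ObjectId. Below is a dependency-free sketch of that logic. Rows are modeled as `Object[]`, and a hypothetical `HexObjectId` stands in for `org.bson.types.ObjectId`. The 24-hex-character check is written out by hand and is assumed to match `ObjectId.isValid`. `ShardKeysSketch` is illustrative only.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class ShardKeysSketch {
    /** Hypothetical stand-in for org.bson.types.ObjectId. */
    static final class HexObjectId {
        final String hex;

        HexObjectId(String hex) {
            this.hex = hex;
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof HexObjectId && ((HexObjectId) o).hex.equals(hex);
        }

        @Override
        public int hashCode() {
            return hex.hashCode();
        }

        @Override
        public String toString() {
            return "ObjectId('" + hex + "')";
        }
    }

    /** Position of each shard key in the schema, as getShardKeysIndexes does (-1 if absent). */
    static int[] shardKeyIndexes(List<String> columnNames, String[] shardKeys) {
        return Arrays.stream(shardKeys).mapToInt(columnNames::indexOf).toArray();
    }

    /** True for exactly 24 ASCII hex digits. */
    static boolean looksLikeObjectId(String s) {
        if (s.length() != 24) {
            return false;
        }
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            boolean hex = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F');
            if (!hex) {
                return false;
            }
        }
        return true;
    }

    /** Builds the shard-key part of the filter for one row. */
    static Map<String, Object> extract(List<String> columnNames, String[] shardKeys, Object[] row) {
        int[] indexes = shardKeyIndexes(columnNames, shardKeys);
        Map<String, Object> shardKeysDoc = new LinkedHashMap<>();
        for (int i = 0; i < shardKeys.length; i++) {
            Object value = row[indexes[i]];
            if (value instanceof String && looksLikeObjectId((String) value)) {
                value = new HexObjectId((String) value); // restore ObjectId from its hex form
            }
            shardKeysDoc.put(shardKeys[i], value);
        }
        return shardKeysDoc;
    }

    public static void main(String[] args) {
        List<String> columns = List.of("_id", "shardKey0", "shardKey1", "status");
        Object[] row = {1L, "65f0c0ffee0123456789abcd", "value1", "INIT"};
        System.out.println(extract(columns, new String[] {"shardKey0", "shardKey1"}, row));
        // prints {shardKey0=ObjectId('65f0c0ffee0123456789abcd'), shardKey1=value1}
    }
}
```

The connector code and this sketch share one consequence. A `STRING` shard key whose value happens to be 24 hex characters is always sent as an ObjectId, so it would not match a document that stores that value as a plain string.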
diff --git a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/serialization/MongoRowDataSerializationSchema.java b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/serialization/MongoRowDataSerializationSchema.java
index 616da64..afd0271 100644
--- a/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/serialization/MongoRowDataSerializationSchema.java
+++ b/flink-connector-mongodb/src/main/java/org/apache/flink/connector/mongodb/table/serialization/MongoRowDataSerializationSchema.java
@@ -39,13 +39,16 @@ import java.util.function.Function;
 public class MongoRowDataSerializationSchema implements MongoSerializationSchema<RowData> {
 
     private final RowDataToBsonConverters.RowDataToBsonConverter rowDataToBsonConverter;
-    private final Function<RowData, BsonValue> createKey;
+    private final Function<RowData, BsonValue> primaryKeyExtractor;
+    private final Function<RowData, BsonDocument> shardKeysExtractor;
 
     public MongoRowDataSerializationSchema(
             RowDataToBsonConverters.RowDataToBsonConverter rowDataToBsonConverter,
-            Function<RowData, BsonValue> createKey) {
+            Function<RowData, BsonValue> primaryKeyExtractor,
+            Function<RowData, BsonDocument> shardKeysExtractor) {
         this.rowDataToBsonConverter = rowDataToBsonConverter;
-        this.createKey = createKey;
+        this.primaryKeyExtractor = primaryKeyExtractor;
+        this.shardKeysExtractor = shardKeysExtractor;
     }
 
     @Override
@@ -64,10 +67,18 @@ public class MongoRowDataSerializationSchema implements MongoSerializationSchema
 
     private WriteModel<BsonDocument> processUpsert(RowData row) {
         final BsonDocument document = rowDataToBsonConverter.convert(row);
-        final BsonValue key = createKey.apply(row);
+        final BsonValue key = primaryKeyExtractor.apply(row);
         if (key != null) {
             BsonDocument filter = new BsonDocument("_id", key);
-            // _id is immutable so we remove it here to prevent exception.
+
+            // For an upsert on a sharded collection, the full shard key must be included
+            // in the filter.
+            BsonDocument shardKeysFilter = shardKeysExtractor.apply(row);
+            if (!shardKeysFilter.isEmpty()) {
+                filter.putAll(shardKeysFilter);
+            }
+
+            // _id is immutable, so we remove it here to prevent exception.
             document.remove("_id");
             BsonDocument update = new BsonDocument("$set", document);
             return new UpdateOneModel<>(filter, update, new UpdateOptions().upsert(true));
@@ -77,7 +88,7 @@ public class MongoRowDataSerializationSchema implements MongoSerializationSchema
     }
 
     private WriteModel<BsonDocument> processDelete(RowData row) {
-        final BsonValue key = createKey.apply(row);
+        final BsonValue key = primaryKeyExtractor.apply(row);
         BsonDocument filter = new BsonDocument("_id", key);
         return new DeleteOneModel<>(filter);
     }
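With this change, the upsert filter is `_id` plus whatever the shard keys extractor returns. That is an empty document when no `PARTITIONED BY` clause is declared. The `$set` body is the whole converted row minus `_id`. The sketch below reproduces that shape with plain `LinkedHashMap`s in place of `BsonDocument`. `UpsertFilterSketch` is a hypothetical helper, not connector code, and unlike the connector it copies the document instead of removing `_id` in place.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical model of the filter/update pair built for one upserted row. */
class UpsertFilterSketch {
    /** Equality filter: _id first, then every declared shard key (none if not partitioned). */
    static Map<String, Object> buildFilter(Object id, Map<String, Object> shardKeys) {
        Map<String, Object> filter = new LinkedHashMap<>();
        filter.put("_id", id);
        filter.putAll(shardKeys);
        return filter;
    }

    /** {$set: document without _id}; _id is immutable, so it must not appear in $set. */
    static Map<String, Object> buildUpdate(Map<String, Object> document) {
        Map<String, Object> fields = new LinkedHashMap<>(document);
        fields.remove("_id");
        Map<String, Object> update = new LinkedHashMap<>();
        update.put("$set", fields);
        return update;
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new LinkedHashMap<>();
        doc.put("_id", 1L);
        doc.put("shardKey0", "value0");
        doc.put("status", "D");
        Map<String, Object> shardKeys = new LinkedHashMap<>();
        shardKeys.put("shardKey0", "value0");
        System.out.println(buildFilter(doc.get("_id"), shardKeys)); // prints {_id=1, shardKey0=value0}
        System.out.println(buildUpdate(doc)); // prints {$set={shardKey0=value0, status=D}}
    }
}
```

The shard key columns are also part of the converted row. They therefore appear in both the filter and `$set` with the same values, and the `$set` never changes the shard key of a matched document.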
diff --git a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/source/MongoSourceITCase.java b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/source/MongoSourceITCase.java
index 57a49ff..3f1df50 100644
--- a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/source/MongoSourceITCase.java
+++ b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/source/MongoSourceITCase.java
@@ -44,7 +44,6 @@ import org.apache.flink.util.CollectionUtil;
 import com.mongodb.client.MongoClient;
 import com.mongodb.client.MongoClients;
 import com.mongodb.client.MongoCollection;
-import com.mongodb.client.MongoDatabase;
 import com.mongodb.client.model.Filters;
 import com.mongodb.client.model.IndexOptions;
 import com.mongodb.client.model.UpdateOptions;
@@ -54,6 +53,7 @@ import org.bson.BsonDocument;
 import org.bson.BsonInt32;
 import org.bson.BsonString;
 import org.bson.Document;
+import org.bson.conversions.Bson;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.BeforeAll;
 import org.junit.jupiter.api.Test;
@@ -72,6 +72,10 @@ import java.util.stream.Stream;
 
 import static org.apache.flink.connector.mongodb.common.utils.MongoConstants.DEFAULT_JSON_WRITER_SETTINGS;
 import static org.apache.flink.connector.mongodb.common.utils.MongoConstants.ID_FIELD;
+import static org.apache.flink.connector.mongodb.testutils.MongoTestUtil.CHUNK_SIZE_FIELD;
+import static org.apache.flink.connector.mongodb.testutils.MongoTestUtil.CONFIG_DATABASE;
+import static org.apache.flink.connector.mongodb.testutils.MongoTestUtil.SETTINGS_COLLECTION;
+import static org.apache.flink.connector.mongodb.testutils.MongoTestUtil.VALUE_FIELD;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** IT cases for using Mongo Source. */
@@ -95,12 +99,6 @@ class MongoSourceITCase {
 
     private static MongoClient mongoClient;
 
-    private static final String ADMIN_DATABASE = "admin";
-    private static final String CONFIG_DATABASE = "config";
-    private static final String SETTINGS_COLLECTION = "settings";
-    private static final String CHUNK_SIZE_FIELD = "chunksize";
-    private static final String VALUE_FIELD = "value";
-
     private static final String TEST_DATABASE = "test_source";
     private static final String TEST_COLLECTION = "test_coll";
     private static final String TEST_SHARDED_COLLECTION = "test_sharded_coll";
@@ -132,38 +130,26 @@ class MongoSourceITCase {
         initTestData(TEST_HASHED_KEY_SHARDED_COLLECTION);
 
         // create unique index {f0: 1, f1: 1}.
-        mongoClient
-                .getDatabase(TEST_DATABASE)
-                .getCollection(TEST_SHARDED_COLLECTION)
-                .createIndex(
-                        BsonDocument.parse("{ f0: 1, f1: 1 }"), new IndexOptions().unique(true));
-
-        MongoDatabase admin = mongoClient.getDatabase(ADMIN_DATABASE);
-        // shard test collection with sharded key { f0: 1, f1: 1 }
-        admin.runCommand(
-                BsonDocument.parse(String.format("{ enableSharding: '%s'}", TEST_DATABASE)));
-        admin.runCommand(
-                BsonDocument.parse(
-                        String.format(
-                                "{ shardCollection : '%s.%s', key : { f0: 1, f1: 1 }}",
-                                TEST_DATABASE, TEST_SHARDED_COLLECTION)));
+        Bson indexKeys = BsonDocument.parse("{ f0: 1, f1: 1 }");
+        MongoTestUtil.createIndex(
+                mongoClient,
+                TEST_DATABASE,
+                TEST_SHARDED_COLLECTION,
+                indexKeys,
+                new IndexOptions().unique(true));
+        MongoTestUtil.shardCollection(
+                mongoClient, TEST_DATABASE, TEST_SHARDED_COLLECTION, indexKeys);
 
         // create hashed index {f1: 'hashed'}.
-        mongoClient
-                .getDatabase(TEST_DATABASE)
-                .getCollection(TEST_HASHED_KEY_SHARDED_COLLECTION)
-                .createIndex(BsonDocument.parse("{ f1: 'hashed' }"), new IndexOptions());
-
-        // shard test collection with hashed sharded key { f1: 'hashed' }
-        admin.runCommand(
-                BsonDocument.parse(
-                        String.format(
-                                "{ enableSharding: '%s'}", TEST_HASHED_KEY_SHARDED_COLLECTION)));
-        admin.runCommand(
-                BsonDocument.parse(
-                        String.format(
-                                "{ shardCollection : '%s.%s', key : { f1: 'hashed' }}",
-                                TEST_DATABASE, TEST_HASHED_KEY_SHARDED_COLLECTION)));
+        Bson hashedIndexKeys = BsonDocument.parse("{ f1: 'hashed' }");
+        MongoTestUtil.createIndex(
+                mongoClient,
+                TEST_DATABASE,
+                TEST_HASHED_KEY_SHARDED_COLLECTION,
+                hashedIndexKeys,
+                new IndexOptions());
+        MongoTestUtil.shardCollection(
+                mongoClient, TEST_DATABASE, TEST_HASHED_KEY_SHARDED_COLLECTION, hashedIndexKeys);
     }
 
     @AfterAll
diff --git a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactoryTest.java b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactoryTest.java
index e2337f3..c9af0fc 100644
--- a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactoryTest.java
+++ b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoDynamicTableFactoryTest.java
@@ -102,9 +102,8 @@ public class MongoDynamicTableFactoryTest {
                         getConnectionOptions(),
                         MongoWriteOptions.builder().build(),
                         null,
-                        SCHEMA.getPrimaryKey().isPresent(),
-                        SCHEMA.toPhysicalRowDataType(),
-                        MongoKeyExtractor.createKeyExtractor(SCHEMA));
+                        SCHEMA,
+                        new String[0]);
         assertThat(actualSink).isEqualTo(expectedSink);
     }
 
@@ -191,12 +190,7 @@ public class MongoDynamicTableFactoryTest {
 
         MongoDynamicTableSink expected =
                 new MongoDynamicTableSink(
-                        connectionOptions,
-                        writeOptions,
-                        null,
-                        SCHEMA.getPrimaryKey().isPresent(),
-                        SCHEMA.toPhysicalRowDataType(),
-                        MongoKeyExtractor.createKeyExtractor(SCHEMA));
+                        connectionOptions, writeOptions, null, SCHEMA, new String[0]);
 
         assertThat(actual).isEqualTo(expected);
     }
@@ -214,12 +208,7 @@ public class MongoDynamicTableFactoryTest {
 
         MongoDynamicTableSink expected =
                 new MongoDynamicTableSink(
-                        connectionOptions,
-                        writeOptions,
-                        2,
-                        SCHEMA.getPrimaryKey().isPresent(),
-                        SCHEMA.toPhysicalRowDataType(),
-                        MongoKeyExtractor.createKeyExtractor(SCHEMA));
+                        connectionOptions, writeOptions, 2, SCHEMA, new String[0]);
 
         assertThat(actual).isEqualTo(expected);
     }
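In the updated factory tests the sink is built from the full `ResolvedSchema` plus a `String[]` that, judging from the partitioned-table IT case below, carries the `PARTITIONED BY` columns; an unpartitioned table passes `new String[0]`. A minimal sketch of resolving such an array against the physical columns follows. `ShardKeyPositions` and `resolve` are hypothetical names for illustration, not part of the connector.

```java
import java.util.List;

/** Hypothetical helper: maps PARTITIONED BY column names to positions in the physical row. */
final class ShardKeyPositions {
    private ShardKeyPositions() {}

    /** Returns the index of each partition key in physicalColumns, in declaration order. */
    static int[] resolve(List<String> physicalColumns, String[] partitionKeys) {
        int[] positions = new int[partitionKeys.length];
        for (int i = 0; i < partitionKeys.length; i++) {
            int pos = physicalColumns.indexOf(partitionKeys[i]);
            if (pos < 0) {
                // A shard key that is not a physical column could never be written.
                throw new IllegalArgumentException(
                        "Partition key '" + partitionKeys[i] + "' is not a physical column");
            }
            positions[i] = pos;
        }
        return positions;
    }
}
```

For the IT-case schema `(a, b, c, d, e)` with `PARTITIONED BY (b,c)` this yields positions 1 and 2; an empty array yields no positions, which corresponds to an unsharded write.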
diff --git a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoPartitionedTableSinkITCase.java b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoPartitionedTableSinkITCase.java
new file mode 100644
index 0000000..9e6c98d
--- /dev/null
+++ b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoPartitionedTableSinkITCase.java
@@ -0,0 +1,290 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.connector.mongodb.table;
+
+import org.apache.flink.connector.mongodb.testutils.MongoShardedContainers;
+import org.apache.flink.connector.mongodb.testutils.MongoTestUtil;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.EnvironmentSettings;
+import org.apache.flink.table.api.TableEnvironment;
+import org.apache.flink.table.expressions.Expression;
+import org.apache.flink.test.junit5.MiniClusterExtension;
+
+import com.mongodb.client.MongoClient;
+import com.mongodb.client.MongoClients;
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.model.Filters;
+import com.mongodb.client.model.IndexOptions;
+import org.bson.BsonDocument;
+import org.bson.Document;
+import org.bson.conversions.Bson;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+import org.testcontainers.containers.Network;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.connector.mongodb.testutils.MongoTestUtil.getConnectorSql;
+import static org.apache.flink.table.api.Expressions.nullOf;
+import static org.apache.flink.table.api.Expressions.row;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+class MongoPartitionedTableSinkITCase {
+
+    @RegisterExtension
+    static final MongoShardedContainers MONGO_SHARDED_CONTAINER =
+            MongoTestUtil.createMongoDBShardedContainers(Network.newNetwork());
+
+    @RegisterExtension
+    static final MiniClusterExtension MINI_CLUSTER_RESOURCE =
+            new MiniClusterExtension(
+                    new MiniClusterResourceConfiguration.Builder()
+                            .setNumberTaskManagers(1)
+                            .build());
+
+    private MongoClient mongoClient;
+
+    @BeforeEach
+    void setUp() {
+        mongoClient = MongoClients.create(MONGO_SHARDED_CONTAINER.getConnectionString());
+    }
+
+    @AfterEach
+    void tearDown() {
+        if (mongoClient != null) {
+            mongoClient.close();
+        }
+    }
+
+    @Test
+    void testSinkIntoPartitionedTable() throws Exception {
+        String database = "test";
+        String collection = "sink_into_sharded_collection";
+
+        // sink into sharded collection by unique index { b: 1, c: 1 }.
+        Bson hashedIndex = BsonDocument.parse("{ b: 1, c: 1 }");
+        MongoTestUtil.createIndex(
+                mongoClient, database, collection, hashedIndex, new IndexOptions().unique(true));
+        MongoTestUtil.shardCollection(mongoClient, database, collection, hashedIndex);
+
+        List<Expression> testValues =
+                Arrays.asList(
+                        row(1L, nullOf(DataTypes.BOOLEAN()), "ABCDEF", 12.12d, 4),
+                        row(1L, nullOf(DataTypes.BOOLEAN()), "ABCDEF", 12.123d, 5));
+        List<String> primaryKeys = Collections.singletonList("a");
+
+        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
+        createPartitionedTable(tEnv, database, collection, primaryKeys, Arrays.asList("b", "c"));
+
+        tEnv.fromValues(testValues).executeInsert("mongo_sink").await();
+
+        MongoCollection<Document> coll =
+                mongoClient.getDatabase(database).getCollection(collection);
+
+        Document expected = new Document();
+        expected.put("_id", 1L);
+        expected.put("a", 1L);
+        expected.put("b", null);
+        expected.put("c", "ABCDEF");
+        expected.put("d", 12.123d);
+        expected.put("e", 5);
+
+        assertThat(coll.find(Filters.eq("_id", 1L))).containsExactly(expected);
+    }
+
+    @Test
+    void testSinkIntoPartitionedTableWithMutableShardKey() {
+        String database = "test";
+        String collection = "sink_into_mutable_sharded_collection";
+
+        // sink into sharded collection by unique index { b: 1, c: 1 }.
+        Bson hashedIndex = BsonDocument.parse("{ b: 1, c: 1 }");
+        MongoTestUtil.createIndex(
+                mongoClient, database, collection, hashedIndex, new IndexOptions().unique(true));
+        MongoTestUtil.shardCollection(mongoClient, database, collection, hashedIndex);
+
+        List<Expression> testValues =
+                Arrays.asList(
+                        row(1L, false, "ABCDEF", 12.12d, 4), row(1L, true, "ABCDEF", 12.123d, 5));
+        List<String> primaryKeys = Collections.singletonList("a");
+
+        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
+        createPartitionedTable(tEnv, database, collection, primaryKeys, Arrays.asList("b", "c"));
+
+        // Updating the shard key value is expected to fail.
+        assertThatThrownBy(() -> tEnv.fromValues(testValues).executeInsert("mongo_sink").await())
+                .hasStackTraceContaining("Writing records to MongoDB failed");
+    }
+
+    @Test
+    void testSinkIntoHashedPartitionedTable() throws Exception {
+        String database = "test";
+        String collection = "sink_into_hashed_sharded_collection";
+
+        // sink into sharded collection by hashed index { c: 'hashed' }.
+        Bson hashedIndex = BsonDocument.parse("{ c: 'hashed' }");
+        MongoTestUtil.createIndex(
+                mongoClient, database, collection, hashedIndex, new IndexOptions());
+        MongoTestUtil.shardCollection(mongoClient, database, collection, hashedIndex);
+
+        List<Expression> testValues =
+                Arrays.asList(
+                        row(2L, true, "ABCDEF", 12.12d, 4), row(2L, false, "ABCDEF", 12.123d, 5));
+
+        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
+        createPartitionedTable(
+                tEnv,
+                database,
+                collection,
+                Collections.singletonList("a"),
+                Collections.singletonList("c"));
+
+        tEnv.fromValues(testValues).executeInsert("mongo_sink").await();
+
+        MongoCollection<Document> coll =
+                mongoClient.getDatabase(database).getCollection(collection);
+
+        Document expected = new Document();
+        expected.put("_id", 2L);
+        expected.put("a", 2L);
+        expected.put("b", false);
+        expected.put("c", "ABCDEF");
+        expected.put("d", 12.123d);
+        expected.put("e", 5);
+
+        assertThat(coll.find(Filters.eq("_id", 2L))).containsExactly(expected);
+    }
+
+    @Test
+    void testSinkIntoPartitionedTableAll() throws Exception {
+        String database = "test";
+        String collection = "sink_into_sharded_collection_all";
+
+        // sink into static sharded collection by unique index { b: 1, c: 1 }.
+        Bson hashedIndex = BsonDocument.parse("{ b: 1, c: 1 }");
+        MongoTestUtil.createIndex(
+                mongoClient, database, collection, hashedIndex, new IndexOptions().unique(true));
+        MongoTestUtil.shardCollection(mongoClient, database, collection, hashedIndex);
+
+        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
+        createPartitionedTable(
+                tEnv,
+                database,
+                collection,
+                Collections.singletonList("a"),
+                Arrays.asList("b", "c"));
+
+        tEnv.executeSql(
+                        "INSERT INTO mongo_sink PARTITION (b='true', c='ABCDEF') SELECT 3, 12.1234, 5")
+                .await();
+        tEnv.executeSql(
+                        "INSERT INTO mongo_sink PARTITION (b='true', c='ABCDEF') SELECT 3, 12.12345, 6")
+                .await();
+
+        MongoCollection<Document> coll =
+                mongoClient.getDatabase(database).getCollection(collection);
+
+        Document expected = new Document();
+        expected.put("_id", 3L);
+        expected.put("a", 3L);
+        expected.put("b", true);
+        expected.put("c", "ABCDEF");
+        expected.put("d", 12.12345d);
+        expected.put("e", 6);
+
+        assertThat(coll.find(Filters.eq("_id", 3L))).containsExactly(expected);
+    }
+
+    @Test
+    void testSinkIntoPartitionedTablePart() throws Exception {
+        String database = "test";
+        String collection = "sink_into_sharded_collection_part";
+
+        // sink into static sharded collection by unique index { c: 1, b: 1 }.
+        Bson hashedIndex = BsonDocument.parse("{ c: 1, b: 1 }");
+        MongoTestUtil.createIndex(
+                mongoClient, database, collection, hashedIndex, new IndexOptions().unique(true));
+        MongoTestUtil.shardCollection(mongoClient, database, collection, hashedIndex);
+
+        TableEnvironment tEnv = TableEnvironment.create(EnvironmentSettings.inStreamingMode());
+        createPartitionedTable(
+                tEnv,
+                database,
+                collection,
+                Collections.singletonList("a"),
+                Arrays.asList("c", "b"));
+
+        tEnv.executeSql(
+                        "INSERT INTO mongo_sink PARTITION (c='ABCDEFG') SELECT 4, false, 12.12345, 6")
+                .await();
+        tEnv.executeSql(
+                        "INSERT INTO mongo_sink PARTITION (c='ABCDEFG') SELECT 4, false, 12.123456, 7")
+                .await();
+
+        MongoCollection<Document> coll =
+                mongoClient.getDatabase(database).getCollection(collection);
+
+        Document expected = new Document();
+        expected.put("_id", 4L);
+        expected.put("a", 4L);
+        expected.put("b", false);
+        expected.put("c", "ABCDEFG");
+        expected.put("d", 12.123456d);
+        expected.put("e", 7);
+
+        assertThat(coll.find(Filters.eq("_id", 4L))).containsExactly(expected);
+    }
+
+    private static void createPartitionedTable(
+            TableEnvironment tEnv,
+            String database,
+            String collection,
+            List<String> primaryKeys,
+            Collection<String> shardKeys) {
+
+        tEnv.executeSql(
+                String.format(
+                        "CREATE TABLE mongo_sink ("
+                                + "a BIGINT NOT NULL,\n"
+                                + "b BOOLEAN,\n"
+                                + "c STRING NOT NULL,\n"
+                                + "d DOUBLE,\n"
+                                + "e INT NOT NULL,\n"
+                                + "PRIMARY KEY (%s) NOT ENFORCED\n"
+                                + ") "
+                                + "PARTITIONED BY (%s)\n"
+                                + "WITH (%s)",
+                        formatKeys(primaryKeys),
+                        formatKeys(shardKeys),
+                        getConnectorSql(
+                                database,
+                                collection,
+                                MONGO_SHARDED_CONTAINER.getConnectionString())));
+    }
+
+    private static String formatKeys(Collection<String> fieldNames) {
+        return String.join(",", fieldNames);
+    }
+}
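The partitioned-table tests that write changelog rows differ in whether the shard key changes between two updates of the same primary key. MongoDB requires an upsert into a sharded collection to include the full shard key in its filter, so a plausible filter combines `_id` (derived from the primary key) with an equality match on every `PARTITIONED BY` column. The sketch below models that with plain maps instead of BSON; `UpsertFilterSketch` and `buildFilter` are hypothetical and are not the connector's actual serialization code.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: builds the equality filter for an upsert into a sharded collection. */
final class UpsertFilterSketch {
    private UpsertFilterSketch() {}

    /**
     * @param row column name to value for one changelog row
     * @param primaryKey the single primary-key column, written to MongoDB as "_id"
     * @param shardKeys the PARTITIONED BY columns, i.e. the collection's shard key fields
     */
    static Map<String, Object> buildFilter(
            Map<String, Object> row, String primaryKey, List<String> shardKeys) {
        Map<String, Object> filter = new LinkedHashMap<>();
        filter.put("_id", row.get(primaryKey));
        for (String key : shardKeys) {
            // Every shard key field is matched by equality; a null value stays null.
            filter.put(key, row.get(key));
        }
        return filter;
    }
}
```

For the rows in `testSinkIntoPartitionedTableWithMutableShardKey`, both filters carry `_id: 1` but differ in `b`, so the second write cannot match and replace the document produced by the first. The test asserts only that the job fails with "Writing records to MongoDB failed", not which server error causes it.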
diff --git a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoKeyExtractorTest.java b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoPrimaryKeyExtractorTest.java
similarity index 90%
rename from flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoKeyExtractorTest.java
rename to flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoPrimaryKeyExtractorTest.java
index a0d7df6..3214017 100644
--- a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoKeyExtractorTest.java
+++ b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoPrimaryKeyExtractorTest.java
@@ -49,8 +49,8 @@ import java.util.function.Function;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
-/** Tests for {@link MongoKeyExtractor}. */
-public class MongoKeyExtractorTest {
+/** Tests for {@link MongoPrimaryKeyExtractor}. */
+public class MongoPrimaryKeyExtractorTest {
 
     @Test
     public void testSinglePrimaryKey() {
@@ -62,7 +62,8 @@ public class MongoKeyExtractorTest {
                         Collections.emptyList(),
                         UniqueConstraint.primaryKey("pk", Collections.singletonList("a")));
 
-        Function<RowData, BsonValue> keyExtractor = MongoKeyExtractor.createKeyExtractor(schema);
+        Function<RowData, BsonValue> keyExtractor =
+                MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema);
 
         BsonValue key = keyExtractor.apply(GenericRowData.of(12L, StringData.fromString("ABCD")));
         assertThat(key).isEqualTo(new BsonInt64(12L));
@@ -78,7 +79,8 @@ public class MongoKeyExtractorTest {
                         Collections.emptyList(),
                         UniqueConstraint.primaryKey("pk", Collections.singletonList("_id")));
 
-        Function<RowData, BsonValue> keyExtractor = MongoKeyExtractor.createKeyExtractor(schema);
+        Function<RowData, BsonValue> keyExtractor =
+                MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema);
 
         ObjectId objectId = new ObjectId();
         BsonValue key =
@@ -99,7 +101,7 @@ public class MongoKeyExtractorTest {
                         Collections.emptyList(),
                         UniqueConstraint.primaryKey("pk", Collections.singletonList("_id, a")));
 
-        assertThatThrownBy(() -> MongoKeyExtractor.createKeyExtractor(schema0))
+        assertThatThrownBy(() -> MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema0))
                 .isInstanceOf(IllegalArgumentException.class)
                 .hasMessageMatching("Ambiguous keys .*");
 
@@ -111,7 +113,7 @@ public class MongoKeyExtractorTest {
                         Collections.emptyList(),
                         null);
 
-        assertThatThrownBy(() -> MongoKeyExtractor.createKeyExtractor(schema1))
+        assertThatThrownBy(() -> MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema1))
                 .isInstanceOf(IllegalArgumentException.class)
                 .hasMessageMatching("Ambiguous keys .*");
     }
@@ -126,7 +128,8 @@ public class MongoKeyExtractorTest {
                         Collections.emptyList(),
                         null);
 
-        Function<RowData, BsonValue> keyExtractor = MongoKeyExtractor.createKeyExtractor(schema);
+        Function<RowData, BsonValue> keyExtractor =
+                MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema);
 
         BsonValue key = keyExtractor.apply(GenericRowData.of(12L, StringData.fromString("ABCD")));
         assertThat(key).isNull();
@@ -143,7 +146,8 @@ public class MongoKeyExtractorTest {
                         Collections.emptyList(),
                         UniqueConstraint.primaryKey("pk", Arrays.asList("a", "b")));
 
-        Function<RowData, BsonValue> keyExtractor = MongoKeyExtractor.createKeyExtractor(schema);
+        Function<RowData, BsonValue> keyExtractor =
+                MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema);
 
         BsonValue key =
                 keyExtractor.apply(
@@ -177,7 +181,8 @@ public class MongoKeyExtractorTest {
                         UniqueConstraint.primaryKey(
                                 "pk", Arrays.asList("a", "b", "c", "d", "e", "f", "g")));
 
-        Function<RowData, BsonValue> keyExtractor = MongoKeyExtractor.createKeyExtractor(schema);
+        Function<RowData, BsonValue> keyExtractor =
+                MongoPrimaryKeyExtractor.createPrimaryKeyExtractor(schema);
 
         BsonValue key =
                 keyExtractor.apply(
diff --git a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoShardKeysExtractorTest.java b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoShardKeysExtractorTest.java
new file mode 100644
index 0000000..48cd208
--- /dev/null
+++ b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/table/MongoShardKeysExtractorTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.connector.mongodb.table;
+
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.catalog.Column;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.catalog.UniqueConstraint;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.StringData;
+
+import org.bson.BsonDocument;
+import org.bson.BsonInt64;
+import org.bson.BsonObjectId;
+import org.bson.BsonString;
+import org.bson.types.ObjectId;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.function.Function;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link MongoShardKeysExtractor}. */
+class MongoShardKeysExtractorTest {
+
+    @Test
+    void testSingleShardKey() {
+        ResolvedSchema schema =
+                new ResolvedSchema(
+                        Arrays.asList(
+                                Column.physical("a", DataTypes.BIGINT().notNull()),
+                                Column.physical("b", DataTypes.STRING())),
+                        Collections.emptyList(),
+                        UniqueConstraint.primaryKey("pk", Collections.singletonList("a")));
+
+        String[] shardKeys = new String[] {"b"};
+
+        Function<RowData, BsonDocument> shardKeysExtractor =
+                MongoShardKeysExtractor.createShardKeysExtractor(schema, shardKeys);
+
+        BsonDocument actual =
+                shardKeysExtractor.apply(GenericRowData.of(12L, StringData.fromString("ABCD")));
+        assertThat(actual).isEqualTo(new BsonDocument("b", new BsonString("ABCD")));
+    }
+
+    @Test
+    void testCompoundShardKey() {
+        ResolvedSchema schema =
+                new ResolvedSchema(
+                        Arrays.asList(
+                                Column.physical("a", DataTypes.BIGINT().notNull()),
+                                Column.physical("b", DataTypes.STRING().notNull()),
+                                Column.physical("c", DataTypes.BIGINT())),
+                        Collections.emptyList(),
+                        UniqueConstraint.primaryKey("pk", Collections.singletonList("a")));
+
+        String[] shardKeys = new String[] {"b", "c"};
+
+        Function<RowData, BsonDocument> shardKeysExtractor =
+                MongoShardKeysExtractor.createShardKeysExtractor(schema, shardKeys);
+
+        BsonDocument actual =
+                shardKeysExtractor.apply(
+                        GenericRowData.of(12L, StringData.fromString("ABCD"), 13L));
+        assertThat(actual)
+                .isEqualTo(
+                        new BsonDocument("b", new BsonString("ABCD"))
+                                .append("c", new BsonInt64(13L)));
+    }
+
+    @Test
+    void testCompoundShardKeyWithObjectId() {
+        ResolvedSchema schema =
+                new ResolvedSchema(
+                        Arrays.asList(
+                                Column.physical("a", DataTypes.STRING().notNull()),
+                                Column.physical("b", DataTypes.STRING().notNull()),
+                                Column.physical("c", DataTypes.BIGINT())),
+                        Collections.emptyList(),
+                        UniqueConstraint.primaryKey("pk", Collections.singletonList("a")));
+
+        String[] shardKeys = new String[] {"a", "b"};
+
+        Function<RowData, BsonDocument> shardKeysExtractor =
+                MongoShardKeysExtractor.createShardKeysExtractor(schema, shardKeys);
+
+        ObjectId objectId = new ObjectId();
+        BsonDocument actual =
+                shardKeysExtractor.apply(
+                        GenericRowData.of(
+                                StringData.fromString(objectId.toString()),
+                                StringData.fromString("ABCD"),
+                                13L));
+        assertThat(actual)
+                .isEqualTo(
+                        new BsonDocument("a", new BsonObjectId(objectId))
+                                .append("b", new BsonString("ABCD")));
+    }
+}
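`testCompoundShardKeyWithObjectId` expects a STRING shard key value holding a hex ObjectId to come back as a `BsonObjectId`. A dependency-free sketch of that projection follows, assuming any string in the 24-hex-digit ObjectId form is converted (the test only exercises field `a`). `ShardKeysSketch` and `ObjectIdValue` are hypothetical stand-ins for the extractor and `org.bson.types.ObjectId`.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Pattern;

/** Hypothetical sketch of projecting shard key fields out of a row. */
final class ShardKeysSketch {
    // 24 hexadecimal characters: the string form of a MongoDB ObjectId.
    private static final Pattern OBJECT_ID_HEX = Pattern.compile("[0-9a-fA-F]{24}");

    private ShardKeysSketch() {}

    /** Stand-in for org.bson.types.ObjectId so the sketch has no driver dependency. */
    static final class ObjectIdValue {
        final String hex;

        ObjectIdValue(String hex) {
            this.hex = hex.toLowerCase();
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof ObjectIdValue && ((ObjectIdValue) o).hex.equals(hex);
        }

        @Override
        public int hashCode() {
            return hex.hashCode();
        }
    }

    /** Copies the shard key fields, in shard key order, converting ObjectId-shaped strings. */
    static Map<String, Object> extract(Map<String, Object> row, String[] shardKeys) {
        Map<String, Object> doc = new LinkedHashMap<>();
        for (String key : shardKeys) {
            Object value = row.get(key);
            // Restore ObjectIds that were carried through Flink as STRING values.
            if (value instanceof String && OBJECT_ID_HEX.matcher((String) value).matches()) {
                value = new ObjectIdValue((String) value);
            }
            doc.put(key, value);
        }
        return doc;
    }
}
```

A `LinkedHashMap` keeps the fields in shard key declaration order, mirroring how the expected `BsonDocument`s in the tests are built with `append`; columns outside the shard key (such as `c` above) are left out.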
diff --git a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/testutils/MongoTestUtil.java b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/testutils/MongoTestUtil.java
index 246f2bc..e41a71e 100644
--- a/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/testutils/MongoTestUtil.java
+++ b/flink-connector-mongodb/src/test/java/org/apache/flink/connector/mongodb/testutils/MongoTestUtil.java
@@ -18,10 +18,16 @@
 package org.apache.flink.connector.mongodb.testutils;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.connector.mongodb.table.MongoConnectorOptions;
+import org.apache.flink.table.factories.FactoryUtil;
 
+import com.mongodb.client.MongoClient;
 import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoDatabase;
 import com.mongodb.client.model.Filters;
+import com.mongodb.client.model.IndexOptions;
 import org.bson.Document;
+import org.bson.conversions.Bson;
 import org.slf4j.Logger;
 import org.testcontainers.containers.MongoDBContainer;
 import org.testcontainers.containers.Network;
@@ -41,6 +47,12 @@ public class MongoTestUtil {
 
     public static final String MONGO_4_0 = "mongo:4.0.10";
 
+    public static final String ADMIN_DATABASE = "admin";
+    public static final String CONFIG_DATABASE = "config";
+    public static final String SETTINGS_COLLECTION = "settings";
+    public static final String CHUNK_SIZE_FIELD = "chunksize";
+    public static final String VALUE_FIELD = "value";
+
     private MongoTestUtil() {}
 
     /**
@@ -90,4 +102,36 @@ public class MongoTestUtil {
         }
         assertThatIdsAreWritten(coll, ids);
     }
+
+    public static String getConnectorSql(
+            String database, String collection, String connectionString) {
+        return String.format("'%s'='%s',\n", FactoryUtil.CONNECTOR.key(), "mongodb")
+                + String.format("'%s'='%s',\n", MongoConnectorOptions.URI.key(), connectionString)
+                + String.format("'%s'='%s',\n", MongoConnectorOptions.DATABASE.key(), database)
+                + String.format("'%s'='%s'\n", MongoConnectorOptions.COLLECTION.key(), collection);
+    }
+
+    public static void createIndex(
+            MongoClient mongoClient,
+            String databaseName,
+            String collectionName,
+            Bson keys,
+            IndexOptions indexOptions) {
+        mongoClient
+                .getDatabase(databaseName)
+                .getCollection(collectionName)
+                .createIndex(keys, indexOptions);
+    }
+
+    public static void shardCollection(
+            MongoClient mongoClient, String databaseName, String collectionName, Bson keys) {
+        MongoDatabase admin = mongoClient.getDatabase(ADMIN_DATABASE);
+        Document enableShardingCommand = new Document("enableSharding", databaseName);
+        admin.runCommand(enableShardingCommand);
+
+        Document shardCollectionCommand =
+                new Document("shardCollection", databaseName + "." + collectionName)
+                        .append("key", keys);
+        admin.runCommand(shardCollectionCommand);
+    }
 }
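`getConnectorSql` emits the body of a `WITH (...)` clause, one `'key'='value'` pair per line, with no comma after the last pair. The sketch below writes the option keys out literally, assuming `FactoryUtil.CONNECTOR`, `MongoConnectorOptions.URI`, `DATABASE` and `COLLECTION` resolve to `connector`, `uri`, `database` and `collection`. `ConnectorSqlSketch` is a hypothetical name.

```java
/** Hypothetical sketch of the WITH clause body assembled by MongoTestUtil.getConnectorSql. */
final class ConnectorSqlSketch {
    private ConnectorSqlSketch() {}

    static String connectorSql(String database, String collection, String uri) {
        // Same layout as getConnectorSql: one 'key'='value' pair per line.
        return String.format("'%s'='%s',\n", "connector", "mongodb")
                + String.format("'%s'='%s',\n", "uri", uri)
                + String.format("'%s'='%s',\n", "database", database)
                // The last option has no trailing comma, so the clause can be closed directly.
                + String.format("'%s'='%s'\n", "collection", collection);
    }
}
```

Values are interpolated without escaping, which is fine for test container URIs but would produce invalid DDL if a value contained a single quote.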