This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6cd15e7bed1 BigTable write SchemaTransform and Python wrapper (#27146)
6cd15e7bed1 is described below

commit 6cd15e7bed1206cea55a7fd02a715e595ed2a966
Author: Ahmed Abualsaud <[email protected]>
AuthorDate: Wed Jun 28 01:14:17 2023 +0000

    BigTable write SchemaTransform and Python wrapper (#27146)
    
    * python wrapper
    
    * schematransform; started some tests
    
    * schematransform and tests done
    
    * python wrapper and tests done
    
    * add tests for _DirectRowMutationsToBeamRow DoFn
    
    * use expansion service
    
    * style fix
    
    * style fix
    
    * use secrets token in instance name to prevent race condition between py37 & py311 tests
    
    * use typing typehints
    
    * style fix
    
    * add missing import
    
    * test fix
    
    * resolve merge issues
    
    * style fix
    
    * raise timeout of dataflow tests
    
    * fix setting timeout
    
    * address reviewer comments
    
    * address reviewer comments
    
    * trivial change
    
    * remove unused imports; lint fixes
    
    * add unittest main call
---
 ...Commit_Python_CrossLanguage_Gcp_Dataflow.groovy |   2 +-
 .../BigtableWriteSchemaTransformProvider.java      | 253 +++++++++++++
 .../BigtableWriteSchemaTransformProviderIT.java    | 413 +++++++++++++++++++++
 sdks/python/apache_beam/io/gcp/bigtableio.py       | 135 +++++--
 .../apache_beam/io/gcp/bigtableio_it_test.py       | 377 +++++++++++++++++++
 sdks/python/apache_beam/io/gcp/bigtableio_test.py  | 247 +++++++-----
 6 files changed, 1304 insertions(+), 123 deletions(-)

diff --git a/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy b/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy
index d1676fcae46..d1ee27088c7 100644
--- a/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy
+++ b/.test-infra/jenkins/job_PostCommit_Python_CrossLanguage_Gcp_Dataflow.groovy
@@ -32,7 +32,7 @@ PostcommitJobBuilder.postCommitJob('beam_PostCommit_Python_Xlang_Gcp_Dataflow',
 
 
       // Set common parameters.
-      commonJobProperties.setTopLevelMainJobProperties(delegate)
+      commonJobProperties.setTopLevelMainJobProperties(delegate, 'master', 180)
 
 
       // Publish all test results to Jenkins
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteSchemaTransformProvider.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteSchemaTransformProvider.java
new file mode 100644
index 00000000000..f57ea46dcdb
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteSchemaTransformProvider.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable;
+
+import static java.util.Optional.ofNullable;
+import static 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.service.AutoService;
+import com.google.auto.value.AutoValue;
+import com.google.bigtable.v2.Mutation;
+import com.google.bigtable.v2.TimestampRange;
+import com.google.protobuf.ByteString;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import 
org.apache.beam.sdk.io.gcp.bigtable.BigtableWriteSchemaTransformProvider.BigtableWriteSchemaTransformConfiguration;
+import org.apache.beam.sdk.schemas.AutoValueSchema;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
+import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
+import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Longs;
+
+/**
+ * An implementation of {@link TypedSchemaTransformProvider} for Bigtable 
Write jobs configured via
+ * {@link BigtableWriteSchemaTransformConfiguration}.
+ *
+ * <p><b>Internal only:</b> This class is actively being worked on, and it 
will likely change. We
+ * provide no backwards compatibility guarantees, and it should not be 
implemented outside the Beam
+ * repository.
+ */
+@AutoService(SchemaTransformProvider.class)
+public class BigtableWriteSchemaTransformProvider
+    extends 
TypedSchemaTransformProvider<BigtableWriteSchemaTransformConfiguration> {
+
+  private static final String INPUT_TAG = "input";
+
+  @Override
+  protected Class<BigtableWriteSchemaTransformConfiguration> 
configurationClass() {
+    return BigtableWriteSchemaTransformConfiguration.class;
+  }
+
+  @Override
+  protected SchemaTransform from(BigtableWriteSchemaTransformConfiguration 
configuration) {
+    return new BigtableWriteSchemaTransform(configuration);
+  }
+
+  @Override
+  public String identifier() {
+    return "beam:schematransform:org.apache.beam:bigtable_write:v1";
+  }
+
+  @Override
+  public List<String> inputCollectionNames() {
+    return Collections.singletonList(INPUT_TAG);
+  }
+
+  @Override
+  public List<String> outputCollectionNames() {
+    return Collections.emptyList();
+  }
+
+  /** Configuration for writing to Bigtable. */
+  @DefaultSchema(AutoValueSchema.class)
+  @AutoValue
+  public abstract static class BigtableWriteSchemaTransformConfiguration {
+    /** Instantiates a {@link 
BigtableWriteSchemaTransformConfiguration.Builder} instance. */
+    public static Builder builder() {
+      return new 
AutoValue_BigtableWriteSchemaTransformProvider_BigtableWriteSchemaTransformConfiguration
+          .Builder();
+    }
+
+    /** Validates the configuration object. */
+    public void validate() {
+      String invalidConfigMessage =
+          "Invalid Bigtable Write configuration: %s should be a non-empty 
String";
+      checkArgument(!this.getTableId().isEmpty(), 
String.format(invalidConfigMessage, "table"));
+      checkArgument(
+          !this.getInstanceId().isEmpty(), String.format(invalidConfigMessage, 
"instance"));
+      checkArgument(!this.getProjectId().isEmpty(), 
String.format(invalidConfigMessage, "project"));
+    }
+
+    public abstract String getTableId();
+
+    public abstract String getInstanceId();
+
+    public abstract String getProjectId();
+
+    /** Builder for the {@link BigtableWriteSchemaTransformConfiguration}. */
+    @AutoValue.Builder
+    public abstract static class Builder {
+      public abstract Builder setTableId(String tableId);
+
+      public abstract Builder setInstanceId(String instanceId);
+
+      public abstract Builder setProjectId(String projectId);
+
+      /** Builds a {@link BigtableWriteSchemaTransformConfiguration} instance. 
*/
+      public abstract BigtableWriteSchemaTransformConfiguration build();
+    }
+  }
+
+  /**
+   * A {@link SchemaTransform} for Bigtable writes, configured with {@link
+   * BigtableWriteSchemaTransformConfiguration} and instantiated by {@link
+   * BigtableWriteSchemaTransformProvider}.
+   */
+  private static class BigtableWriteSchemaTransform
+      extends PTransform<PCollectionRowTuple, PCollectionRowTuple> implements 
SchemaTransform {
+    private final BigtableWriteSchemaTransformConfiguration configuration;
+
+    BigtableWriteSchemaTransform(BigtableWriteSchemaTransformConfiguration 
configuration) {
+      configuration.validate();
+      this.configuration = configuration;
+    }
+
+    @Override
+    public PCollectionRowTuple expand(PCollectionRowTuple input) {
+      checkArgument(
+          input.has(INPUT_TAG),
+          String.format(
+              "Could not find expected input [%s] to %s.", INPUT_TAG, 
getClass().getSimpleName()));
+
+      PCollection<Row> beamRowMutations = input.get(INPUT_TAG);
+      PCollection<KV<ByteString, Iterable<Mutation>>> bigtableMutations =
+          beamRowMutations.apply(MapElements.via(new 
GetMutationsFromBeamRow()));
+
+      bigtableMutations.apply(
+          BigtableIO.write()
+              .withTableId(configuration.getTableId())
+              .withInstanceId(configuration.getInstanceId())
+              .withProjectId(configuration.getProjectId()));
+
+      return PCollectionRowTuple.empty(input.getPipeline());
+    }
+
+    @Override
+    public PTransform<PCollectionRowTuple, PCollectionRowTuple> 
buildTransform() {
+      return this;
+    }
+  }
+
+  public static class GetMutationsFromBeamRow
+      extends SimpleFunction<Row, KV<ByteString, Iterable<Mutation>>> {
+    @Override
+    public KV<ByteString, Iterable<Mutation>> apply(Row row) {
+      ByteString key = 
ByteString.copyFrom(ofNullable(row.getBytes("key")).get());
+      List<Map<String, byte[]>> beamRowMutations =
+          (List) ofNullable(row.getArray("mutations")).get();
+
+      List<Mutation> mutations = new ArrayList<>(beamRowMutations.size());
+
+      for (Map<String, byte[]> mutation : beamRowMutations) {
+        Mutation bigtableMutation;
+        switch (new String(ofNullable(mutation.get("type")).get(), 
StandardCharsets.UTF_8)) {
+          case "SetCell":
+            Mutation.SetCell.Builder setMutation =
+                Mutation.SetCell.newBuilder()
+                    
.setValue(ByteString.copyFrom(ofNullable(mutation.get("value")).get()))
+                    .setColumnQualifier(
+                        
ByteString.copyFrom(ofNullable(mutation.get("column_qualifier")).get()))
+                    .setFamilyNameBytes(
+                        
ByteString.copyFrom(ofNullable(mutation.get("family_name")).get()));
+            if (mutation.containsKey("timestamp_micros")) {
+              setMutation =
+                  setMutation.setTimestampMicros(
+                      
Longs.fromByteArray(ofNullable(mutation.get("timestamp_micros")).get()));
+            }
+            bigtableMutation = 
Mutation.newBuilder().setSetCell(setMutation.build()).build();
+            break;
+          case "DeleteFromColumn":
+            Mutation.DeleteFromColumn.Builder deleteMutation =
+                Mutation.DeleteFromColumn.newBuilder()
+                    .setColumnQualifier(
+                        
ByteString.copyFrom(ofNullable(mutation.get("column_qualifier")).get()))
+                    .setFamilyNameBytes(
+                        
ByteString.copyFrom(ofNullable(mutation.get("family_name")).get()));
+
+            // set timestamp range if applicable
+            if (mutation.containsKey("start_timestamp_micros")
+                || mutation.containsKey("end_timestamp_micros")) {
+              TimestampRange.Builder timeRange = TimestampRange.newBuilder();
+              if (mutation.containsKey("start_timestamp_micros")) {
+                Long startMicros =
+                    
ByteBuffer.wrap(ofNullable(mutation.get("start_timestamp_micros")).get())
+                        .getLong();
+                timeRange.setStartTimestampMicros(startMicros);
+              }
+              if (mutation.containsKey("end_timestamp_micros")) {
+                Long endMicros =
+                    
ByteBuffer.wrap(ofNullable(mutation.get("end_timestamp_micros")).get())
+                        .getLong();
+                timeRange.setEndTimestampMicros(endMicros);
+              }
+              deleteMutation.setTimeRange(timeRange.build());
+            }
+            bigtableMutation =
+                
Mutation.newBuilder().setDeleteFromColumn(deleteMutation.build()).build();
+            break;
+          case "DeleteFromFamily":
+            bigtableMutation =
+                Mutation.newBuilder()
+                    .setDeleteFromFamily(
+                        Mutation.DeleteFromFamily.newBuilder()
+                            .setFamilyNameBytes(
+                                
ByteString.copyFrom(ofNullable(mutation.get("family_name")).get()))
+                            .build())
+                    .build();
+            break;
+          case "DeleteFromRow":
+            bigtableMutation =
+                Mutation.newBuilder()
+                    
.setDeleteFromRow(Mutation.DeleteFromRow.newBuilder().build())
+                    .build();
+            break;
+          default:
+            throw new RuntimeException(
+                String.format(
+                    "Unexpected mutation type [%s]: %s",
+                    Arrays.toString(ofNullable(mutation.get("type")).get()), 
mutation));
+        }
+        mutations.add(bigtableMutation);
+      }
+      return KV.of(key, mutations);
+    }
+  }
+}
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteSchemaTransformProviderIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteSchemaTransformProviderIT.java
new file mode 100644
index 00000000000..2af3153215c
--- /dev/null
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableWriteSchemaTransformProviderIT.java
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.bigtable;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+import com.google.api.gax.rpc.NotFoundException;
+import com.google.cloud.bigtable.admin.v2.BigtableTableAdminClient;
+import com.google.cloud.bigtable.admin.v2.BigtableTableAdminSettings;
+import com.google.cloud.bigtable.admin.v2.models.CreateTableRequest;
+import com.google.cloud.bigtable.data.v2.BigtableDataClient;
+import com.google.cloud.bigtable.data.v2.BigtableDataSettings;
+import com.google.cloud.bigtable.data.v2.models.Query;
+import com.google.cloud.bigtable.data.v2.models.RowCell;
+import com.google.cloud.bigtable.data.v2.models.RowMutation;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
+import 
org.apache.beam.sdk.io.gcp.bigtable.BigtableWriteSchemaTransformProvider.BigtableWriteSchemaTransformConfiguration;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollectionRowTuple;
+import org.apache.beam.sdk.values.Row;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Longs;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class BigtableWriteSchemaTransformProviderIT {
+  @Rule public final transient TestPipeline p = TestPipeline.create();
+
+  private static final String COLUMN_FAMILY_NAME_1 = "test_cf_1";
+  private static final String COLUMN_FAMILY_NAME_2 = "test_cf_2";
+  private BigtableTableAdminClient tableAdminClient;
+  private BigtableDataClient dataClient;
+  private String tableId = 
String.format("BigtableWriteIT-%tF-%<tH-%<tM-%<tS-%<tL", new Date());
+  private String projectId;
+  private String instanceId;
+  private PTransform<PCollectionRowTuple, PCollectionRowTuple> writeTransform;
+  private static final Schema SCHEMA =
+      Schema.builder()
+          .addByteArrayField("key")
+          .addArrayField(
+              "mutations", Schema.FieldType.map(Schema.FieldType.STRING, 
Schema.FieldType.BYTES))
+          .build();
+
+  @Test
+  public void testInvalidConfigs() {
+    System.out.println(writeTransform.getName());
+    // Properties cannot be empty (project, instance, and table)
+    List<BigtableWriteSchemaTransformConfiguration.Builder> invalidConfigs =
+        Arrays.asList(
+            BigtableWriteSchemaTransformConfiguration.builder()
+                .setProjectId("project")
+                .setInstanceId("instance")
+                .setTableId(""),
+            BigtableWriteSchemaTransformConfiguration.builder()
+                .setProjectId("")
+                .setInstanceId("instance")
+                .setTableId("table"),
+            BigtableWriteSchemaTransformConfiguration.builder()
+                .setProjectId("project")
+                .setInstanceId("")
+                .setTableId("table"));
+
+    for (BigtableWriteSchemaTransformConfiguration.Builder config : 
invalidConfigs) {
+      assertThrows(
+          IllegalArgumentException.class,
+          () -> {
+            config.build().validate();
+          });
+    }
+  }
+
+  @Before
+  public void setup() throws Exception {
+    BigtableTestOptions options =
+        TestPipeline.testingPipelineOptions().as(BigtableTestOptions.class);
+    projectId = options.as(GcpOptions.class).getProject();
+    instanceId = options.getInstanceId();
+
+    BigtableDataSettings settings =
+        
BigtableDataSettings.newBuilder().setProjectId(projectId).setInstanceId(instanceId).build();
+    // Creates a bigtable data client.
+    dataClient = BigtableDataClient.create(settings);
+
+    BigtableTableAdminSettings adminSettings =
+        BigtableTableAdminSettings.newBuilder()
+            .setProjectId(projectId)
+            .setInstanceId(instanceId)
+            .build();
+    tableAdminClient = BigtableTableAdminClient.create(adminSettings);
+
+    // set up the table with some pre-written rows to test our mutations on.
+    // each test is independent of the others
+    if (!tableAdminClient.exists(tableId)) {
+      CreateTableRequest createTableRequest =
+          CreateTableRequest.of(tableId)
+              .addFamily(COLUMN_FAMILY_NAME_1)
+              .addFamily(COLUMN_FAMILY_NAME_2);
+      tableAdminClient.createTable(createTableRequest);
+    }
+
+    BigtableWriteSchemaTransformConfiguration config =
+        BigtableWriteSchemaTransformConfiguration.builder()
+            .setProjectId(projectId)
+            .setInstanceId(instanceId)
+            .setTableId(tableId)
+            .build();
+    writeTransform = new 
BigtableWriteSchemaTransformProvider().from(config).buildTransform();
+  }
+
+  @After
+  public void tearDown() {
+    try {
+      tableAdminClient.deleteTable(tableId);
+      System.out.printf("Table %s deleted successfully%n", tableId);
+    } catch (NotFoundException e) {
+      System.err.println("Failed to delete a non-existent table: " + 
e.getMessage());
+    }
+    dataClient.close();
+    tableAdminClient.close();
+  }
+
+  @Test
+  public void testSetMutationsExistingColumn() {
+    RowMutation rowMutation =
+        RowMutation.create(tableId, "key-1")
+            .setCell(COLUMN_FAMILY_NAME_1, "col_a", "val-1-a")
+            .setCell(COLUMN_FAMILY_NAME_2, "col_c", "val-1-c");
+    dataClient.mutateRow(rowMutation);
+
+    List<Map<String, byte[]>> mutations = new ArrayList<>();
+    // mutation to set cell in an existing column
+    mutations.add(
+        ImmutableMap.of(
+            "type", "SetCell".getBytes(StandardCharsets.UTF_8),
+            "value", "new-val-1-a".getBytes(StandardCharsets.UTF_8),
+            "column_qualifier", "col_a".getBytes(StandardCharsets.UTF_8),
+            "family_name", 
COLUMN_FAMILY_NAME_1.getBytes(StandardCharsets.UTF_8)));
+    mutations.add(
+        ImmutableMap.of(
+            "type", "SetCell".getBytes(StandardCharsets.UTF_8),
+            "value", "new-val-1-c".getBytes(StandardCharsets.UTF_8),
+            "column_qualifier", "col_c".getBytes(StandardCharsets.UTF_8),
+            "family_name", 
COLUMN_FAMILY_NAME_2.getBytes(StandardCharsets.UTF_8)));
+    Row mutationRow =
+        Row.withSchema(SCHEMA)
+            .withFieldValue("key", "key-1".getBytes(StandardCharsets.UTF_8))
+            .withFieldValue("mutations", mutations)
+            .build();
+
+    PCollectionRowTuple.of("input", 
p.apply(Create.of(Arrays.asList(mutationRow))))
+        .apply(writeTransform);
+    p.run().waitUntilFinish();
+
+    // get rows from table
+    List<com.google.cloud.bigtable.data.v2.models.Row> rows =
+        
dataClient.readRows(Query.create(tableId)).stream().collect(Collectors.toList());
+    // we should still have only one row with the same key
+    assertEquals(1, rows.size());
+    assertEquals("key-1", rows.get(0).getKey().toStringUtf8());
+
+    // check that we now have two cells in each column we added to and that
+    // the last cell in each column has the updated value
+    com.google.cloud.bigtable.data.v2.models.Row row = rows.get(0);
+    List<RowCell> cellsColA =
+        row.getCells(COLUMN_FAMILY_NAME_1, "col_a").stream()
+            .sorted(RowCell.compareByNative())
+            .collect(Collectors.toList());
+    List<RowCell> cellsColC =
+        row.getCells(COLUMN_FAMILY_NAME_2, "col_c").stream()
+            .sorted(RowCell.compareByNative())
+            .collect(Collectors.toList());
+    assertEquals(2, cellsColA.size());
+    assertEquals(2, cellsColC.size());
+    System.out.println(cellsColA);
+    System.out.println(cellsColC);
+    assertEquals("new-val-1-a", cellsColA.get(1).getValue().toStringUtf8());
+    assertEquals("new-val-1-c", cellsColC.get(1).getValue().toStringUtf8());
+  }
+
+  @Test
+  public void testSetMutationNewColumn() {
+    RowMutation rowMutation =
+        RowMutation.create(tableId, "key-1").setCell(COLUMN_FAMILY_NAME_1, 
"col_a", "val-1-a");
+    dataClient.mutateRow(rowMutation);
+
+    List<Map<String, byte[]>> mutations = new ArrayList<>();
+    // mutation to set cell in a new column
+    mutations.add(
+        ImmutableMap.of(
+            "type", "SetCell".getBytes(StandardCharsets.UTF_8),
+            "value", "new-val-1".getBytes(StandardCharsets.UTF_8),
+            "column_qualifier", "new_col".getBytes(StandardCharsets.UTF_8),
+            "family_name", 
COLUMN_FAMILY_NAME_1.getBytes(StandardCharsets.UTF_8)));
+    Row mutationRow =
+        Row.withSchema(SCHEMA)
+            .withFieldValue("key", "key-1".getBytes(StandardCharsets.UTF_8))
+            .withFieldValue("mutations", mutations)
+            .build();
+
+    PCollectionRowTuple.of("input", 
p.apply(Create.of(Arrays.asList(mutationRow))))
+        .apply(writeTransform);
+    p.run().waitUntilFinish();
+
+    // get rows from table
+    List<com.google.cloud.bigtable.data.v2.models.Row> rows =
+        
dataClient.readRows(Query.create(tableId)).stream().collect(Collectors.toList());
+
+    // we should still have only one row with the same key
+    assertEquals(1, rows.size());
+    assertEquals("key-1", rows.get(0).getKey().toStringUtf8());
+    // check the new column exists with only one cell.
+    // also check cell value is correct
+    com.google.cloud.bigtable.data.v2.models.Row row = rows.get(0);
+    List<RowCell> cellsNewCol = row.getCells(COLUMN_FAMILY_NAME_1, "new_col");
+    assertEquals(1, cellsNewCol.size());
+    assertEquals("new-val-1", cellsNewCol.get(0).getValue().toStringUtf8());
+  }
+
+  @Test
+  public void testDeleteCellsFromColumn() {
+    RowMutation rowMutation =
+        RowMutation.create(tableId, "key-1")
+            .setCell(COLUMN_FAMILY_NAME_1, "col_a", "val-1-a")
+            .setCell(COLUMN_FAMILY_NAME_1, "col_b", "val-1-b");
+    dataClient.mutateRow(rowMutation);
+    // write two cells in col_a. both should get deleted
+    rowMutation =
+        RowMutation.create(tableId, "key-1").setCell(COLUMN_FAMILY_NAME_1, 
"col_a", "new-val-1-a");
+    dataClient.mutateRow(rowMutation);
+
+    List<Map<String, byte[]>> mutations = new ArrayList<>();
+    // mutation to delete cells from a column
+    mutations.add(
+        ImmutableMap.of(
+            "type", "DeleteFromColumn".getBytes(StandardCharsets.UTF_8),
+            "column_qualifier", "col_a".getBytes(StandardCharsets.UTF_8),
+            "family_name", 
COLUMN_FAMILY_NAME_1.getBytes(StandardCharsets.UTF_8)));
+    Row mutationRow =
+        Row.withSchema(SCHEMA)
+            .withFieldValue("key", "key-1".getBytes(StandardCharsets.UTF_8))
+            .withFieldValue("mutations", mutations)
+            .build();
+
+    PCollectionRowTuple.of("input", 
p.apply(Create.of(Arrays.asList(mutationRow))))
+        .apply(writeTransform);
+    p.run().waitUntilFinish();
+
+    // get rows from table
+    List<com.google.cloud.bigtable.data.v2.models.Row> rows =
+        
dataClient.readRows(Query.create(tableId)).stream().collect(Collectors.toList());
+
+    // we should still have one row with the same key
+    assertEquals(1, rows.size());
+    assertEquals("key-1", rows.get(0).getKey().toStringUtf8());
+    // get cells from this column family. we started with three cells and 
deleted two from one
+    // column.
+    // we should end up with one cell in the column we didn't touch.
+    // check that the remaining cell is indeed from col_b
+    com.google.cloud.bigtable.data.v2.models.Row row = rows.get(0);
+    List<RowCell> cells = row.getCells(COLUMN_FAMILY_NAME_1);
+    assertEquals(1, cells.size());
+    assertEquals("col_b", cells.get(0).getQualifier().toStringUtf8());
+  }
+
+  @Test
+  public void testDeleteCellsFromColumnWithTimestampRange() {
+    // write two cells in one column with different timestamps.
+    RowMutation rowMutation =
+        RowMutation.create(tableId, "key-1")
+            .setCell(COLUMN_FAMILY_NAME_1, "col", 100_000_000, "val");
+    dataClient.mutateRow(rowMutation);
+    rowMutation =
+        RowMutation.create(tableId, "key-1")
+            .setCell(COLUMN_FAMILY_NAME_1, "col", 200_000_000, "new-val");
+    dataClient.mutateRow(rowMutation);
+
+    List<Map<String, byte[]>> mutations = new ArrayList<>();
+    // mutation to delete cells from a column within a timestamp range
+    mutations.add(
+        ImmutableMap.of(
+            "type", "DeleteFromColumn".getBytes(StandardCharsets.UTF_8),
+            "column_qualifier", "col".getBytes(StandardCharsets.UTF_8),
+            "family_name", 
COLUMN_FAMILY_NAME_1.getBytes(StandardCharsets.UTF_8),
+            "start_timestamp_micros", Longs.toByteArray(99_999_999),
+            "end_timestamp_micros", Longs.toByteArray(100_000_001)));
+    Row mutationRow =
+        Row.withSchema(SCHEMA)
+            .withFieldValue("key", "key-1".getBytes(StandardCharsets.UTF_8))
+            .withFieldValue("mutations", mutations)
+            .build();
+
+    PCollectionRowTuple.of("input", 
p.apply(Create.of(Arrays.asList(mutationRow))))
+        .apply(writeTransform);
+    p.run().waitUntilFinish();
+
+    // get rows from table
+    List<com.google.cloud.bigtable.data.v2.models.Row> rows =
+        
dataClient.readRows(Query.create(tableId)).stream().collect(Collectors.toList());
+
+    // we should still have one row with the same key
+    assertEquals(1, rows.size());
+    assertEquals("key-1", rows.get(0).getKey().toStringUtf8());
+    // we had two cells in col_a and deleted the older one. we should be left 
with the newer cell.
+    // check cell has correct value and timestamp
+    com.google.cloud.bigtable.data.v2.models.Row row = rows.get(0);
+    List<RowCell> cells = row.getCells(COLUMN_FAMILY_NAME_1, "col");
+    assertEquals(1, cells.size());
+    assertEquals("new-val", cells.get(0).getValue().toStringUtf8());
+    assertEquals(200_000_000, cells.get(0).getTimestamp());
+  }
+
+  @Test
+  public void testDeleteColumnFamily() {
+    RowMutation rowMutation =
+        RowMutation.create(tableId, "key-1")
+            .setCell(COLUMN_FAMILY_NAME_1, "col_a", "val")
+            .setCell(COLUMN_FAMILY_NAME_2, "col_b", "val");
+    dataClient.mutateRow(rowMutation);
+
+    List<Map<String, byte[]>> mutations = new ArrayList<>();
+    // mutation to delete a whole column family
+    mutations.add(
+        ImmutableMap.of(
+            "type", "DeleteFromFamily".getBytes(StandardCharsets.UTF_8),
+            "family_name", 
COLUMN_FAMILY_NAME_1.getBytes(StandardCharsets.UTF_8)));
+    Row mutationRow =
+        Row.withSchema(SCHEMA)
+            .withFieldValue("key", "key-1".getBytes(StandardCharsets.UTF_8))
+            .withFieldValue("mutations", mutations)
+            .build();
+
+    PCollectionRowTuple.of("input", 
p.apply(Create.of(Arrays.asList(mutationRow))))
+        .apply(writeTransform);
+    p.run().waitUntilFinish();
+
+    // get rows from table
+    List<com.google.cloud.bigtable.data.v2.models.Row> rows =
+        
dataClient.readRows(Query.create(tableId)).stream().collect(Collectors.toList());
+
+    // we should still have one row with the same key
+    assertEquals(1, rows.size());
+    assertEquals("key-1", rows.get(0).getKey().toStringUtf8());
+    // we had one cell in each of two column families. we deleted a column 
family, so should end up
+    // with
+    // one cell in the column family we didn't touch.
+    com.google.cloud.bigtable.data.v2.models.Row row = rows.get(0);
+    List<RowCell> cells = row.getCells();
+    assertEquals(1, cells.size());
+    assertEquals(COLUMN_FAMILY_NAME_2, cells.get(0).getFamily());
+  }
+
+  @Test
+  public void testDeleteRow() {
+    RowMutation rowMutation =
+        RowMutation.create(tableId, "key-1").setCell(COLUMN_FAMILY_NAME_1, 
"col", "val-1");
+    dataClient.mutateRow(rowMutation);
+    rowMutation =
+        RowMutation.create(tableId, "key-2").setCell(COLUMN_FAMILY_NAME_1, 
"col", "val-2");
+    dataClient.mutateRow(rowMutation);
+
+    List<Map<String, byte[]>> mutations = new ArrayList<>();
+    // mutation to delete a whole row
+    mutations.add(ImmutableMap.of("type", 
"DeleteFromRow".getBytes(StandardCharsets.UTF_8)));
+    Row mutationRow =
+        Row.withSchema(SCHEMA)
+            .withFieldValue("key", "key-1".getBytes(StandardCharsets.UTF_8))
+            .withFieldValue("mutations", mutations)
+            .build();
+
+    PCollectionRowTuple.of("input", 
p.apply(Create.of(Arrays.asList(mutationRow))))
+        .apply(writeTransform);
+    p.run().waitUntilFinish();
+
+    // get rows from table
+    List<com.google.cloud.bigtable.data.v2.models.Row> rows =
+        
dataClient.readRows(Query.create(tableId)).stream().collect(Collectors.toList());
+
+    // we created two rows then deleted one, so should end up with the row we 
didn't touch
+    assertEquals(1, rows.size());
+    assertEquals("key-2", rows.get(0).getKey().toStringUtf8());
+  }
+}
diff --git a/sdks/python/apache_beam/io/gcp/bigtableio.py b/sdks/python/apache_beam/io/gcp/bigtableio.py
index eedfb8f1c81..b2b52bd675c 100644
--- a/sdks/python/apache_beam/io/gcp/bigtableio.py
+++ b/sdks/python/apache_beam/io/gcp/bigtableio.py
@@ -38,6 +38,9 @@ those generated rows in the table.
 # pytype: skip-file
 
 import logging
+import struct
+from typing import Dict
+from typing import List
 
 import apache_beam as beam
 from apache_beam.internal.metrics.metric import ServiceCallMetric
@@ -48,6 +51,7 @@ from apache_beam.transforms import PTransform
 from apache_beam.transforms.display import DisplayDataItem
 from apache_beam.transforms.external import BeamJarExpansionService
 from apache_beam.transforms.external import SchemaAwareExternalTransform
+from apache_beam.typehints.row_type import RowTypeConstraint
 
 _LOGGER = logging.getLogger(__name__)
 
@@ -168,34 +172,119 @@ class _BigTableWriteFn(beam.DoFn):
 
 
 class WriteToBigTable(beam.PTransform):
-  """ A transform to write to the Bigtable Table.
+  """A transform that writes rows to a Bigtable table.
 
-  A PTransform that write a list of `DirectRow` into the Bigtable Table
+  Takes an input PCollection of `DirectRow` objects containing un-committed
+  mutations. For more information about this row object, visit
+  
https://cloud.google.com/python/docs/reference/bigtable/latest/row#class-googlecloudbigtablerowdirectrowrowkey-tablenone
 
+  If flag `use_cross_language` is set to true, this transform will use the
+  multi-language transforms framework to inject the Java native write transform
+  into the pipeline.
   """
-  def __init__(self, project_id=None, instance_id=None, table_id=None):
-    """ The PTransform to access the Bigtable Write connector
-    Args:
-      project_id(str): GCP Project of to write the Rows
-      instance_id(str): GCP Instance to write the Rows
-      table_id(str): GCP Table to write the `DirectRows`
+  URN = "beam:schematransform:org.apache.beam:bigtable_write:v1"
+
+  def __init__(
+      self,
+      project_id,
+      instance_id,
+      table_id,
+      use_cross_language=False,
+      expansion_service=None):
+    """Initialize an WriteToBigTable transform.
+
+    :param table_id:
+      The ID of the table to write to.
+    :param instance_id:
+      The ID of the instance where the table resides.
+    :param project_id:
+      The GCP project ID.
+    :param use_cross_language:
+      If set to True, will use the Java native transform via cross-language.
+    :param expansion_service:
+      The address of the expansion service in the case of using cross-language.
+      If no expansion service is provided, will attempt to run the default GCP
+      expansion service.
     """
     super().__init__()
-    self.beam_options = {
-        'project_id': project_id,
-        'instance_id': instance_id,
-        'table_id': table_id
-    }
+    self._table_id = table_id
+    self._instance_id = instance_id
+    self._project_id = project_id
+    self._use_cross_language = use_cross_language
+    if use_cross_language:
+      self._expansion_service = (
+          expansion_service or BeamJarExpansionService(
+              'sdks:java:io:google-cloud-platform:expansion-service:build'))
+      self.schematransform_config = (
+          SchemaAwareExternalTransform.discover_config(
+              self._expansion_service, self.URN))
 
-  def expand(self, pvalue):
-    beam_options = self.beam_options
-    return (
-        pvalue
-        | beam.ParDo(
-            _BigTableWriteFn(
-                beam_options['project_id'],
-                beam_options['instance_id'],
-                beam_options['table_id'])))
+  def expand(self, input):
+    if self._use_cross_language:
+      external_write = SchemaAwareExternalTransform(
+          identifier=self.schematransform_config.identifier,
+          expansion_service=self._expansion_service,
+          rearrange_based_on_discovery=True,
+          tableId=self._table_id,
+          instanceId=self._instance_id,
+          projectId=self._project_id)
+
+      return (
+          input
+          | beam.ParDo(self._DirectRowMutationsToBeamRow()).with_output_types(
+              RowTypeConstraint.from_fields(
+                  [("key", bytes), ("mutations", List[Dict[str, bytes]])]))
+          | external_write)
+    else:
+      return (
+          input
+          | beam.ParDo(
+              _BigTableWriteFn(
+                  self._project_id, self._instance_id, self._table_id)))
+
+  class _DirectRowMutationsToBeamRow(beam.DoFn):
+    def process(self, direct_row):
+      args = {"key": direct_row.row_key, "mutations": []}
+      # start accumulating mutations in a list
+      for mutation in direct_row._get_mutations():
+        if mutation.__contains__("set_cell"):
+          mutation_dict = {
+              "type": b'SetCell',
+              "family_name": mutation.set_cell.family_name.encode('utf-8'),
+              "column_qualifier": mutation.set_cell.column_qualifier,
+              "value": mutation.set_cell.value
+          }
+          micros = mutation.set_cell.timestamp_micros
+          if micros > -1:
+            mutation_dict['timestamp_micros'] = struct.pack('>q', micros)
+        elif mutation.__contains__("delete_from_column"):
+          mutation_dict = {
+              "type": b'DeleteFromColumn',
+              "family_name": mutation.delete_from_column.family_name.encode(
+                  'utf-8'),
+              "column_qualifier": mutation.delete_from_column.column_qualifier
+          }
+          time_range = mutation.delete_from_column.time_range
+          if time_range.start_timestamp_micros:
+            mutation_dict['start_timestamp_micros'] = struct.pack(
+                '>q', time_range.start_timestamp_micros)
+          if time_range.end_timestamp_micros:
+            mutation_dict['end_timestamp_micros'] = struct.pack(
+                '>q', time_range.end_timestamp_micros)
+        elif mutation.__contains__("delete_from_family"):
+          mutation_dict = {
+              "type": b'DeleteFromFamily',
+              "family_name": mutation.delete_from_family.family_name.encode(
+                  'utf-8')
+          }
+        elif mutation.__contains__("delete_from_row"):
+          mutation_dict = {"type": b'DeleteFromRow'}
+        else:
+          raise ValueError("Unexpected mutation")
+
+        args["mutations"].append(mutation_dict)
+
+      yield beam.Row(**args)
 
 
 class ReadFromBigtable(PTransform):
@@ -207,7 +296,7 @@ class ReadFromBigtable(PTransform):
   """
   URN = "beam:schematransform:org.apache.beam:bigtable_read:v1"
 
-  def __init__(self, table_id, instance_id, project_id, 
expansion_service=None):
+  def __init__(self, project_id, instance_id, table_id, 
expansion_service=None):
     """Initialize a ReadFromBigtable transform.
 
     :param table_id:
diff --git a/sdks/python/apache_beam/io/gcp/bigtableio_it_test.py 
b/sdks/python/apache_beam/io/gcp/bigtableio_it_test.py
new file mode 100644
index 00000000000..4c26da7012d
--- /dev/null
+++ b/sdks/python/apache_beam/io/gcp/bigtableio_it_test.py
@@ -0,0 +1,377 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Integration tests for BigTable service."""
+
+import logging
+import os
+import secrets
+import time
+import unittest
+# pytype: skip-file
+from datetime import datetime
+from datetime import timezone
+
+import pytest
+
+import apache_beam as beam
+from apache_beam.io.gcp import bigtableio
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+_LOGGER = logging.getLogger(__name__)
+
+# Protect against environments where bigtable library is not available.
+try:
+  from apitools.base.py.exceptions import HttpError
+  from google.cloud.bigtable import client
+  from google.cloud.bigtable.row_filters import TimestampRange
+  from google.cloud.bigtable.row import DirectRow, PartialRowData, Cell
+  from google.cloud.bigtable.table import Table
+  from google.cloud.bigtable_admin_v2.types import instance
+except ImportError as e:
+  client = None
+  HttpError = None
+
+
[email protected]_gcp_java_expansion_service
[email protected](
+    os.environ.get('EXPANSION_PORT'),
+    "EXPANSION_PORT environment var is not provided.")
[email protected](client is None, 'Bigtable dependencies are not installed')
+class TestReadFromBigTableIT(unittest.TestCase):
+  INSTANCE = "bt-read-tests"
+  TABLE_ID = "test-table"
+
+  def setUp(self):
+    self.test_pipeline = TestPipeline(is_integration_test=True)
+    self.args = self.test_pipeline.get_full_options_as_args()
+    self.project = self.test_pipeline.get_option('project')
+    self.expansion_service = ('localhost:%s' % 
os.environ.get('EXPANSION_PORT'))
+
+    instance_id = '%s-%s-%s' % (
+        self.INSTANCE, str(int(time.time())), secrets.token_hex(3))
+
+    self.client = client.Client(admin=True, project=self.project)
+    # create cluster and instance
+    self.instance = self.client.instance(
+        instance_id,
+        display_name=self.INSTANCE,
+        instance_type=instance.Instance.Type.DEVELOPMENT)
+    cluster = self.instance.cluster("test-cluster", "us-central1-a")
+    operation = self.instance.create(clusters=[cluster])
+    operation.result(timeout=500)
+    _LOGGER.info(
+        "Created instance [%s] in project [%s]",
+        self.instance.instance_id,
+        self.project)
+
+    # create table inside instance
+    self.table = self.instance.table(self.TABLE_ID)
+    self.table.create()
+    _LOGGER.info("Created table [%s]", self.table.table_id)
+
+  def tearDown(self):
+    try:
+      _LOGGER.info(
+          "Deleting table [%s] and instance [%s]",
+          self.table.table_id,
+          self.instance.instance_id)
+      self.table.delete()
+      self.instance.delete()
+    except HttpError:
+      _LOGGER.debug(
+          "Failed to clean up table [%s] and instance [%s]",
+          self.table.table_id,
+          self.instance.instance_id)
+
+  def add_rows(self, num_rows, num_families, num_columns_per_family):
+    cells = []
+    for i in range(1, num_rows + 1):
+      key = 'key-' + str(i)
+      row = DirectRow(key, self.table)
+      for j in range(num_families):
+        fam_name = 'test_col_fam_' + str(j)
+        # create the table's column families one time only
+        if i == 1:
+          col_fam = self.table.column_family(fam_name)
+          col_fam.create()
+        for k in range(1, num_columns_per_family + 1):
+          row.set_cell(fam_name, f'col-{j}-{k}', f'value-{i}-{j}-{k}')
+
+      # after all mutations on the row are done, commit to Bigtable
+      row.commit()
+      # read the same row back from Bigtable to get the expected data
+      # accumulate rows in `cells`
+      read_row: PartialRowData = self.table.read_row(key)
+      cells.append(read_row.cells)
+
+    return cells
+
+  def test_read_xlang(self):
+    # create rows and retrieve expected cells
+    expected_cells = self.add_rows(
+        num_rows=5, num_families=3, num_columns_per_family=4)
+
+    with beam.Pipeline(argv=self.args) as p:
+      cells = (
+          p
+          | bigtableio.ReadFromBigtable(
+              project_id=self.project,
+              instance_id=self.instance.instance_id,
+              table_id=self.table.table_id,
+              expansion_service=self.expansion_service)
+          | "Extract cells" >> beam.Map(lambda row: row._cells))
+
+      assert_that(cells, equal_to(expected_cells))
+
+
[email protected]_gcp_java_expansion_service
[email protected](
+    os.environ.get('EXPANSION_PORT'),
+    "EXPANSION_PORT environment var is not provided.")
[email protected](client is None, 'Bigtable dependencies are not installed')
+class TestWriteToBigtableXlangIT(unittest.TestCase):
+  # These are integration tests for the cross-language write transform.
+  INSTANCE = "bt-write-xlang"
+  TABLE_ID = "test-table"
+
+  @classmethod
+  def setUpClass(cls):
+    cls.test_pipeline = TestPipeline(is_integration_test=True)
+    cls.project = cls.test_pipeline.get_option('project')
+    cls.args = cls.test_pipeline.get_full_options_as_args()
+    cls.expansion_service = ('localhost:%s' % os.environ.get('EXPANSION_PORT'))
+
+    instance_id = '%s-%s-%s' % (
+        cls.INSTANCE, str(int(time.time())), secrets.token_hex(3))
+
+    cls.client = client.Client(admin=True, project=cls.project)
+    # create cluster and instance
+    cls.instance = cls.client.instance(
+        instance_id,
+        display_name=cls.INSTANCE,
+        instance_type=instance.Instance.Type.DEVELOPMENT)
+    cluster = cls.instance.cluster("test-cluster", "us-central1-a")
+    operation = cls.instance.create(clusters=[cluster])
+    operation.result(timeout=500)
+    _LOGGER.warning(
+        "Created instance [%s] in project [%s]",
+        cls.instance.instance_id,
+        cls.project)
+
+  def setUp(self):
+    # create table inside instance
+    self.table: Table = self.instance.table(
+        '%s-%s-%s' %
+        (self.TABLE_ID, str(int(time.time())), secrets.token_hex(3)))
+    self.table.create()
+    _LOGGER.info("Created table [%s]", self.table.table_id)
+
+  def tearDown(self):
+    try:
+      _LOGGER.info("Deleting table [%s]", self.table.table_id)
+      self.table.delete()
+    except HttpError:
+      _LOGGER.debug("Failed to clean up table [%s]", self.table.table_id)
+
+  @classmethod
+  def tearDownClass(cls):
+    try:
+      _LOGGER.info("Deleting instance [%s]", cls.instance.instance_id)
+      cls.instance.delete()
+    except HttpError:
+      _LOGGER.debug(
+          "Failed to clean up instance [%s]", cls.instance.instance_id)
+
+  def run_pipeline(self, rows):
+    with beam.Pipeline(argv=self.args) as p:
+      _ = (
+          p
+          | beam.Create(rows)
+          | bigtableio.WriteToBigTable(
+              project_id=self.project,
+              instance_id=self.instance.instance_id,
+              table_id=self.table.table_id,
+              use_cross_language=True,
+              expansion_service=self.expansion_service))
+
+  def test_set_mutation(self):
+    row1: DirectRow = DirectRow('key-1')
+    row2: DirectRow = DirectRow('key-2')
+    col_fam = self.table.column_family('col_fam')
+    col_fam.create()
+    # expected cells
+    row1_col1_cell = Cell(b'val1-1', 100_000_000)
+    row1_col2_cell = Cell(b'val1-2', 200_000_000)
+    row2_col1_cell = Cell(b'val2-1', 100_000_000)
+    row2_col2_cell = Cell(b'val2-2', 200_000_000)
+    # rows sent to write transform
+    row1.set_cell(
+        'col_fam', b'col-1', row1_col1_cell.value, row1_col1_cell.timestamp)
+    row1.set_cell(
+        'col_fam', b'col-2', row1_col2_cell.value, row1_col2_cell.timestamp)
+    row2.set_cell(
+        'col_fam', b'col-1', row2_col1_cell.value, row2_col1_cell.timestamp)
+    row2.set_cell(
+        'col_fam', b'col-2', row2_col2_cell.value, row2_col2_cell.timestamp)
+
+    self.run_pipeline([row1, row2])
+
+    # after write transform executes, get actual rows from table
+    actual_row1: PartialRowData = self.table.read_row('key-1')
+    actual_row2: PartialRowData = self.table.read_row('key-2')
+
+    # check actual rows match with expected rows (value and timestamp)
+    self.assertEqual(
+        row1_col1_cell, actual_row1.find_cells('col_fam', b'col-1')[0])
+    self.assertEqual(
+        row1_col2_cell, actual_row1.find_cells('col_fam', b'col-2')[0])
+    self.assertEqual(
+        row2_col1_cell, actual_row2.find_cells('col_fam', b'col-1')[0])
+    self.assertEqual(
+        row2_col2_cell, actual_row2.find_cells('col_fam', b'col-2')[0])
+
+  def test_delete_cells_mutation(self):
+    col_fam = self.table.column_family('col_fam')
+    col_fam.create()
+    # write a row with two columns to the table beforehand.
+    write_row: DirectRow = DirectRow('key-1', self.table)
+    write_row.set_cell('col_fam', b'col-1', b'val-1')
+    write_row.set_cell('col_fam', b'col-2', b'val-2')
+    write_row.commit()
+
+    # prepare a row that will delete cells in one of the columns.
+    delete_row: DirectRow = DirectRow('key-1')
+    delete_row.delete_cell('col_fam', b'col-1')
+
+    self.run_pipeline([delete_row])
+
+    # after transform executes, get actual row from table
+    actual_row: PartialRowData = self.table.read_row('key-1')
+
+    # we deleted all the cells in 'col-1', so this check should throw an error
+    with self.assertRaises(KeyError):
+      actual_row.find_cells('col_fam', b'col-1')
+
+    # check the cell in col-2 is still there
+    col2_cells = actual_row.find_cells('col_fam', b'col-2')
+    self.assertEqual(1, len(col2_cells))
+    self.assertEqual(b'val-2', col2_cells[0].value)
+
+  def test_delete_cells_with_timerange_mutation(self):
+    col_fam = self.table.column_family('col_fam')
+    col_fam.create()
+    # write two cells in a column to the table beforehand.
+    write_row: DirectRow = DirectRow('key-1', self.table)
+    write_row.set_cell(
+        'col_fam',
+        b'col',
+        b'val',
+        datetime.fromtimestamp(100_000_000, tz=timezone.utc))
+    write_row.commit()
+    write_row.set_cell(
+        'col_fam',
+        b'col',
+        b'new-val',
+        datetime.fromtimestamp(200_000_000, tz=timezone.utc))
+    write_row.commit()
+
+    # prepare a row that will delete cells within a timestamp range.
+    delete_row: DirectRow = DirectRow('key-1')
+    delete_row.delete_cell(
+        'col_fam',
+        b'col',
+        time_range=TimestampRange(
+            start=datetime.fromtimestamp(99_999_999, tz=timezone.utc),
+            end=datetime.fromtimestamp(100_000_001, tz=timezone.utc)))
+
+    self.run_pipeline([delete_row])
+
+    # after transform executes, get actual row from table
+    actual_row: PartialRowData = self.table.read_row('key-1')
+
+    # we deleted one cell within the timestamp range.
+    # check the other (newer) cell still exists.
+    cells = actual_row.find_cells('col_fam', b'col')
+    self.assertEqual(1, len(cells))
+    self.assertEqual(b'new-val', cells[0].value)
+    self.assertEqual(
+        datetime.fromtimestamp(200_000_000, tz=timezone.utc),
+        cells[0].timestamp)
+
+  def test_delete_column_family_mutation(self):
+    # create two column families
+    col_fam = self.table.column_family('col_fam-1')
+    col_fam.create()
+    col_fam = self.table.column_family('col_fam-2')
+    col_fam.create()
+    # write a row with values in both column families to the table beforehand.
+    write_row: DirectRow = DirectRow('key-1', self.table)
+    write_row.set_cell('col_fam-1', b'col', b'val')
+    write_row.set_cell('col_fam-2', b'col', b'val')
+    write_row.commit()
+
+    # prepare a row that will delete a column family from the row
+    delete_row: DirectRow = DirectRow('key-1')
+    delete_row.delete_cells('col_fam-1', delete_row.ALL_COLUMNS)
+
+    self.run_pipeline([delete_row])
+
+    # after transform executes, get actual row from table
+    actual_row: PartialRowData = self.table.read_row('key-1')
+
+    # we deleted column family 'col_fam-1', so this check should throw an error
+    with self.assertRaises(KeyError):
+      actual_row.find_cells('col_fam-1', b'col')
+
+    # check we have one column family left with the correct cell value
+    self.assertEqual(1, len(actual_row.cells))
+    self.assertEqual(b'val', actual_row.cell_value('col_fam-2', b'col'))
+
+  def test_delete_row_mutation(self):
+    write_row1: DirectRow = DirectRow('key-1', self.table)
+    write_row2: DirectRow = DirectRow('key-2', self.table)
+    col_fam = self.table.column_family('col_fam')
+    col_fam.create()
+    # write a couple of rows to the table beforehand
+    write_row1.set_cell('col_fam', b'col', b'val-1')
+    write_row1.commit()
+    write_row2.set_cell('col_fam', b'col', b'val-2')
+    write_row2.commit()
+
+    # prepare a row that will delete itself
+    delete_row: DirectRow = DirectRow('key-1')
+    delete_row.delete()
+
+    self.run_pipeline([delete_row])
+
+    # after write transform executes, get actual rows from table
+    actual_row1: PartialRowData = self.table.read_row('key-1')
+    actual_row2: PartialRowData = self.table.read_row('key-2')
+
+    # we deleted row with key 'key-1', check it doesn't exist anymore
+    # the Bigtable API doesn't throw an error here, just returns a None value
+    self.assertEqual(None, actual_row1)
+    # check row 2 exists with the correct cell value in col
+    self.assertEqual(b'val-2', actual_row2.cell_value('col_fam', b'col'))
+
+
+if __name__ == '__main__':
+  # when run directly, lower the root log level so the INFO-level
+  # setup/teardown messages from these tests are emitted
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
diff --git a/sdks/python/apache_beam/io/gcp/bigtableio_test.py 
b/sdks/python/apache_beam/io/gcp/bigtableio_test.py
index 012e0478ada..f97c9bcfbd6 100644
--- a/sdks/python/apache_beam/io/gcp/bigtableio_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigtableio_test.py
@@ -17,18 +17,15 @@
 
 """Unit tests for BigTable service."""
 
-# pytype: skip-file
 import logging
-import os
-import secrets
 import string
-import time
 import unittest
 import uuid
+# pytype: skip-file
 from datetime import datetime
+from datetime import timezone
 from random import choice
 
-import pytest
 from mock import MagicMock
 from mock import patch
 
@@ -38,113 +35,20 @@ from apache_beam.io.gcp import bigtableio
 from apache_beam.io.gcp import resource_identifiers
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricsEnvironment
-from apache_beam.testing.test_pipeline import TestPipeline
-from apache_beam.testing.util import assert_that
-from apache_beam.testing.util import equal_to
 
 _LOGGER = logging.getLogger(__name__)
 
 # Protect against environments where bigtable library is not available.
 try:
-  from apitools.base.py.exceptions import HttpError
   from google.cloud.bigtable import client
+  from google.cloud.bigtable.row_filters import TimestampRange
   from google.cloud.bigtable.instance import Instance
   from google.cloud.bigtable.row import DirectRow, PartialRowData, Cell
   from google.cloud.bigtable.table import Table
-  from google.cloud.bigtable_admin_v2.types import instance
   from google.rpc.code_pb2 import OK, ALREADY_EXISTS
   from google.rpc.status_pb2 import Status
 except ImportError as e:
   client = None
-  HttpError = None
-
-
[email protected]_gcp_java_expansion_service
[email protected](
-    os.environ.get('EXPANSION_PORT'),
-    "EXPANSION_PORT environment var is not provided.")
[email protected](client is None, 'Bigtable dependencies are not installed')
-class TestReadFromBigTable(unittest.TestCase):
-  INSTANCE = "bt-read-tests"
-  TABLE_ID = "test-table"
-
-  def setUp(self):
-    self.test_pipeline = TestPipeline(is_integration_test=True)
-    self.args = self.test_pipeline.get_full_options_as_args()
-    self.project = self.test_pipeline.get_option('project')
-
-    instance_id = '%s-%s-%s' % (
-        self.INSTANCE, str(int(time.time())), secrets.token_hex(3))
-
-    self.client = client.Client(admin=True, project=self.project)
-    # create cluster and instance
-    self.instance = self.client.instance(
-        instance_id,
-        display_name=self.INSTANCE,
-        instance_type=instance.Instance.Type.DEVELOPMENT)
-    cluster = self.instance.cluster("test-cluster", "us-central1-a")
-    operation = self.instance.create(clusters=[cluster])
-    operation.result(timeout=500)
-    _LOGGER.info(
-        "Created instance [%s] in project [%s]",
-        self.instance.instance_id,
-        self.project)
-
-    # create table inside instance
-    self.table = self.instance.table(self.TABLE_ID)
-    self.table.create()
-    _LOGGER.info("Created table [%s]", self.table.table_id)
-
-  def tearDown(self):
-    try:
-      _LOGGER.info(
-          "Deleting table [%s] and instance [%s]",
-          self.table.table_id,
-          self.instance.instance_id)
-      self.table.delete()
-      self.instance.delete()
-    except HttpError:
-      _LOGGER.debug(
-          "Failed to clean up table [%s] and instance [%s]",
-          self.table.table_id,
-          self.instance.instance_id)
-
-  def add_rows(self, num_rows, num_families, num_columns_per_family):
-    cells = []
-    for i in range(1, num_rows + 1):
-      key = 'key-' + str(i)
-      row = DirectRow(key, self.table)
-      for j in range(num_families):
-        fam_name = 'test_col_fam_' + str(j)
-        # create the table's column families one time only
-        if i == 1:
-          col_fam = self.table.column_family(fam_name)
-          col_fam.create()
-        for k in range(1, num_columns_per_family + 1):
-          row.set_cell(fam_name, f'col-{j}-{k}', f'value-{i}-{j}-{k}')
-
-      # after all mutations on the row are done, commit to Bigtable
-      row.commit()
-      # read the same row back from Bigtable to get the expected data
-      # accumulate rows in `cells`
-      read_row: PartialRowData = self.table.read_row(key)
-      cells.append(read_row.cells)
-
-    return cells
-
-  def test_read_xlang(self):
-    # create rows and retrieve expected cells
-    expected_cells = self.add_rows(
-        num_rows=5, num_families=3, num_columns_per_family=4)
-
-    with beam.Pipeline(argv=self.args) as p:
-      cells = (
-          p
-          | bigtableio.ReadFromBigtable(
-              self.table.table_id, self.instance.instance_id, self.project)
-          | "Extract cells" >> beam.Map(lambda row: row._cells))
-
-      assert_that(cells, equal_to(expected_cells))
 
 
 @unittest.skipIf(client is None, 'Bigtable dependencies are not installed')
@@ -205,6 +109,151 @@ class TestBeamRowToPartialRowData(unittest.TestCase):
                      bigtable_row.find_cells('family_2', b'column_qualifier'))
 
 
[email protected](client is None, 'Bigtable dependencies are not installed')
+class TestBigtableDirectRowToBeamRow(unittest.TestCase):
+  doFn = bigtableio.WriteToBigTable._DirectRowMutationsToBeamRow()
+
+  def test_set_cell(self):
+    # create some set cell mutations
+    direct_row: DirectRow = DirectRow('key-1')
+    direct_row.set_cell(
+        'col_fam',
+        b'col',
+        b'a',
+        datetime.fromtimestamp(100_000).replace(tzinfo=timezone.utc))
+    direct_row.set_cell(
+        'col_fam',
+        b'other-col',
+        b'b',
+        datetime.fromtimestamp(200_000).replace(tzinfo=timezone.utc))
+    direct_row.set_cell(
+        'other_col_fam',
+        b'col',
+        b'c',
+        datetime.fromtimestamp(300_000).replace(tzinfo=timezone.utc))
+
+    # get equivalent beam row
+    beam_row = next(self.doFn.process(direct_row))
+
+    # sort both lists of mutations for convenience
+    beam_row_mutations = sorted(beam_row.mutations, key=lambda m: m['value'])
+    bt_row_mutations = sorted(
+        direct_row._get_mutations(), key=lambda m: m.set_cell.value)
+    self.assertEqual(beam_row.key, direct_row.row_key)
+    self.assertEqual(len(beam_row_mutations), len(bt_row_mutations))
+
+    # check that the values in each beam mutation is equal to the original
+    # Bigtable direct row mutations
+    for i in range(len(beam_row_mutations)):
+      beam_mutation = beam_row_mutations[i]
+      bt_mutation = bt_row_mutations[i].set_cell
+
+      self.assertEqual(beam_mutation['type'], b'SetCell')
+      self.assertEqual(
+          beam_mutation['family_name'].decode(), bt_mutation.family_name)
+      self.assertEqual(
+          beam_mutation['column_qualifier'], bt_mutation.column_qualifier)
+      self.assertEqual(beam_mutation['value'], bt_mutation.value)
+      self.assertEqual(
+          int.from_bytes(beam_mutation['timestamp_micros'], 'big'),
+          bt_mutation.timestamp_micros)
+
+  def test_delete_cells(self):
+    # create some delete cell mutations. one with a timestamp range
+    direct_row: DirectRow = DirectRow('key-1')
+    direct_row.delete_cell('col_fam', b'col-1')
+    direct_row.delete_cell(
+        'other_col_fam',
+        b'col-2',
+        time_range=TimestampRange(
+            start=datetime.fromtimestamp(10_000_000, tz=timezone.utc)))
+    direct_row.delete_cells(
+        'another_col_fam', [b'col-3', b'col-4', b'col-5'],
+        time_range=TimestampRange(
+            start=datetime.fromtimestamp(50_000_000, tz=timezone.utc),
+            end=datetime.fromtimestamp(100_000_000, tz=timezone.utc)))
+
+    # get equivalent beam row
+    beam_row = next(self.doFn.process(direct_row))
+
+    # sort both lists of mutations for convenience
+    beam_row_mutations = sorted(
+        beam_row.mutations, key=lambda m: m['column_qualifier'])
+    bt_row_mutations = sorted(
+        direct_row._get_mutations(),
+        key=lambda m: m.delete_from_column.column_qualifier)
+    self.assertEqual(beam_row.key, direct_row.row_key)
+    self.assertEqual(len(beam_row_mutations), len(bt_row_mutations))
+
+    # check that the values in each beam mutation is equal to the original
+    # Bigtable direct row mutations
+    for i in range(len(beam_row_mutations)):
+      beam_mutation = beam_row_mutations[i]
+      bt_mutation = bt_row_mutations[i].delete_from_column
+      print(bt_mutation)
+
+      self.assertEqual(beam_mutation['type'], b'DeleteFromColumn')
+      self.assertEqual(
+          beam_mutation['family_name'].decode(), bt_mutation.family_name)
+      self.assertEqual(
+          beam_mutation['column_qualifier'], bt_mutation.column_qualifier)
+
+      # check we set a timestamp range only when appropriate
+      if bt_mutation.time_range.start_timestamp_micros:
+        self.assertEqual(
+            int.from_bytes(beam_mutation['start_timestamp_micros'], 'big'),
+            bt_mutation.time_range.start_timestamp_micros)
+      else:
+        self.assertTrue('start_timestamp_micros' not in beam_mutation)
+
+      if bt_mutation.time_range.end_timestamp_micros:
+        self.assertEqual(
+            int.from_bytes(beam_mutation['end_timestamp_micros'], 'big'),
+            bt_mutation.time_range.end_timestamp_micros)
+      else:
+        self.assertTrue('end_timestamp_micros' not in beam_mutation)
+
+  def test_delete_column_family(self):
+    # create mutation to delete column family
+    direct_row: DirectRow = DirectRow('key-1')
+    direct_row.delete_cells('col_fam-1', direct_row.ALL_COLUMNS)
+    direct_row.delete_cells('col_fam-2', direct_row.ALL_COLUMNS)
+
+    # get equivalent beam row
+    beam_row = next(self.doFn.process(direct_row))
+
+    # sort both lists of mutations for convenience
+    beam_row_mutations = sorted(
+        beam_row.mutations, key=lambda m: m['family_name'])
+    bt_row_mutations = sorted(
+        direct_row._get_mutations(),
+        key=lambda m: m.delete_from_column.family_name)
+    self.assertEqual(beam_row.key, direct_row.row_key)
+    self.assertEqual(len(beam_row_mutations), len(bt_row_mutations))
+
+    # check that the values in each beam mutation is equal to the original
+    # Bigtable direct row mutations
+    for i in range(len(beam_row_mutations)):
+      beam_mutation = beam_row_mutations[i]
+      bt_mutation = bt_row_mutations[i].delete_from_family
+
+      self.assertEqual(beam_mutation['type'], b'DeleteFromFamily')
+      self.assertEqual(
+          beam_mutation['family_name'].decode(), bt_mutation.family_name)
+
+  def test_delete_row(self):
+    # create mutation to delete the Bigtable row
+    direct_row: DirectRow = DirectRow('key-1')
+    direct_row.delete()
+
+    # get equivalent beam row
+    beam_row = next(self.doFn.process(direct_row))
+    self.assertEqual(beam_row.key, direct_row.row_key)
+
+    beam_mutation = beam_row.mutations[0]
+    self.assertEqual(beam_mutation['type'], b'DeleteFromRow')
+
+
 @unittest.skipIf(client is None, 'Bigtable dependencies are not installed')
 class TestWriteBigTable(unittest.TestCase):
   TABLE_PREFIX = "python-test"

Reply via email to