This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 974d4e3  [BEAM-11982] Java Spanner - Implement IO Request Count metrics (#15493)
974d4e3 is described below

commit 974d4e334664c685add5bfabfc33c36b7f913c54
Author: Benjamin Gonzalez <[email protected]>
AuthorDate: Fri Sep 24 23:15:35 2021 -0500

    [BEAM-11982] Java Spanner - Implement IO Request Count metrics (#15493)
---
 model/pipeline/src/main/proto/metrics.proto        |   4 +
 .../core/metrics/GcpResourceIdentifiers.java       |   9 +
 .../core/metrics/MonitoringInfoConstants.java      |  12 ++
 .../beam/sdk/io/gcp/spanner/BatchSpannerRead.java  |  30 ++++
 .../beam/sdk/io/gcp/spanner/ReadOperation.java     |   8 +
 .../apache/beam/sdk/io/gcp/spanner/SpannerIO.java  |  36 ++++
 .../beam/sdk/io/gcp/spanner/SpannerIOReadTest.java | 182 +++++++++++++++++----
 .../sdk/io/gcp/spanner/SpannerIOWriteTest.java     |  66 ++++++++
 8 files changed, 317 insertions(+), 30 deletions(-)

diff --git a/model/pipeline/src/main/proto/metrics.proto b/model/pipeline/src/main/proto/metrics.proto
index 8f819b6..913b2d0 100644
--- a/model/pipeline/src/main/proto/metrics.proto
+++ b/model/pipeline/src/main/proto/metrics.proto
@@ -420,6 +420,10 @@ message MonitoringInfo {
     BIGTABLE_PROJECT_ID = 20 [(label_props) = { name: "BIGTABLE_PROJECT_ID"}];
     INSTANCE_ID = 21 [(label_props) = { name: "INSTANCE_ID"}];
     TABLE_ID = 22 [(label_props) = { name: "TABLE_ID"}];
+    SPANNER_PROJECT_ID = 23 [(label_props) = { name: "SPANNER_PROJECT_ID"}];
+    SPANNER_DATABASE_ID = 24 [(label_props) = { name: "SPANNER_DATABASE_ID"}];
+    SPANNER_INSTANCE_ID = 25 [(label_props) = { name: "SPANNER_INSTANCE_ID" }];
+    SPANNER_QUERY_NAME = 26 [(label_props) = { name: "SPANNER_QUERY_NAME" }];
   }
 
   // A set of key and value labels which define the scope of the metric. For
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
index 4c388bf..336f08d 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GcpResourceIdentifiers.java
@@ -51,4 +51,13 @@ public class GcpResourceIdentifiers {
     return String.format(
         "//bigtable.googleapis.com/projects/%s/namespaces/%s", projectId, namespace);
   }
+
+  public static String spannerTable(String projectId, String databaseId, String tableId) {
+    return String.format(
+        "//spanner.googleapis.com/projects/%s/topics/%s/tables/%s", projectId, databaseId, tableId);
+  }
+
+  public static String spannerQuery(String projectId, String queryName) {
+    return String.format("//spanner.googleapis.com/projects/%s/queries/%s", projectId, queryName);
+  }
 }
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
index c792719..03496fe 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MonitoringInfoConstants.java
@@ -85,6 +85,10 @@ public final class MonitoringInfoConstants {
     public static final String TABLE_ID = "TABLE_ID";
     public static final String GCS_BUCKET = "GCS_BUCKET";
     public static final String GCS_PROJECT_ID = "GCS_PROJECT_ID";
+    public static final String SPANNER_PROJECT_ID = "SPANNER_PROJECT_ID";
+    public static final String SPANNER_DATABASE_ID = "SPANNER_DATABASE_ID";
+    public static final String SPANNER_INSTANCE_ID = "SPANNER_INSTANCE_ID";
+    public static final String SPANNER_QUERY_NAME = "SPANNER_QUERY_NAME";
 
     static {
       // Note: One benefit of defining these strings above, instead of pulling them in from
@@ -120,6 +124,14 @@ public final class MonitoringInfoConstants {
       checkArgument(TABLE_ID.equals(extractLabel(MonitoringInfoLabels.TABLE_ID)));
       checkArgument(GCS_BUCKET.equals(extractLabel(MonitoringInfoLabels.GCS_BUCKET)));
       checkArgument(GCS_PROJECT_ID.equals(extractLabel(MonitoringInfoLabels.GCS_PROJECT_ID)));
+      checkArgument(
+          SPANNER_PROJECT_ID.equals(extractLabel(MonitoringInfoLabels.SPANNER_PROJECT_ID)));
+      checkArgument(
+          SPANNER_DATABASE_ID.equals(extractLabel(MonitoringInfoLabels.SPANNER_DATABASE_ID)));
+      checkArgument(
+          SPANNER_INSTANCE_ID.equals(extractLabel(MonitoringInfoLabels.SPANNER_INSTANCE_ID)));
+      checkArgument(
+          SPANNER_QUERY_NAME.equals(extractLabel(MonitoringInfoLabels.SPANNER_QUERY_NAME)));
     }
   }
 
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java
index fc24c8f..5393c7d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/BatchSpannerRead.java
@@ -21,9 +21,14 @@ import com.google.auto.value.AutoValue;
 import com.google.cloud.spanner.BatchReadOnlyTransaction;
 import com.google.cloud.spanner.Partition;
 import com.google.cloud.spanner.ResultSet;
+import com.google.cloud.spanner.SpannerException;
 import com.google.cloud.spanner.Struct;
 import com.google.cloud.spanner.TimestampBound;
+import java.util.HashMap;
 import java.util.List;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.ServiceCallMetric;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -158,18 +163,43 @@ abstract class BatchSpannerRead
 
     @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
+      ServiceCallMetric serviceCallMetric =
+          createServiceCallMetric(
+              this.config.getProjectId().toString(),
+              this.config.getDatabaseId().toString(),
+              this.config.getInstanceId().toString());
       Transaction tx = c.sideInput(txView);
 
         BatchReadOnlyTransaction batchTx =
             spannerAccessor.getBatchClient().batchReadOnlyTransaction(tx.transactionId());
 
+      serviceCallMetric.call("ok");
       Partition p = c.element();
       try (ResultSet resultSet = batchTx.execute(p)) {
         while (resultSet.next()) {
           Struct s = resultSet.getCurrentRowAsStruct();
           c.output(s);
         }
+      } catch (SpannerException e) {
+        serviceCallMetric.call(e.getErrorCode().getGrpcStatusCode().toString());
       }
     }
+
+    private ServiceCallMetric createServiceCallMetric(
+        String projectId, String databaseId, String tableId) {
+      HashMap<String, String> baseLabels = new HashMap<>();
+      baseLabels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+      baseLabels.put(MonitoringInfoConstants.Labels.SERVICE, "Spanner");
+      baseLabels.put(MonitoringInfoConstants.Labels.METHOD, "Read");
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.RESOURCE,
+          GcpResourceIdentifiers.spannerTable(projectId, databaseId, tableId));
+      baseLabels.put(MonitoringInfoConstants.Labels.SPANNER_PROJECT_ID, projectId);
+      baseLabels.put(MonitoringInfoConstants.Labels.SPANNER_DATABASE_ID, databaseId);
+      baseLabels.put(MonitoringInfoConstants.Labels.SPANNER_INSTANCE_ID, tableId);
+      ServiceCallMetric serviceCallMetric =
+          new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
+      return serviceCallMetric;
+    }
   }
 }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java
index 1066115..0c9c42d 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/ReadOperation.java
@@ -42,6 +42,8 @@ public abstract class ReadOperation implements Serializable {
 
   public abstract @Nullable Statement getQuery();
 
+  public abstract @Nullable String getQueryName();
+
   public abstract @Nullable String getTable();
 
   public abstract @Nullable String getIndex();
@@ -57,6 +59,8 @@ public abstract class ReadOperation implements Serializable {
 
     abstract Builder setQuery(Statement statement);
 
+    abstract Builder setQueryName(String queryName);
+
     abstract Builder setTable(String table);
 
     abstract Builder setIndex(String index);
@@ -92,6 +96,10 @@ public abstract class ReadOperation implements Serializable {
     return withQuery(Statement.of(sql));
   }
 
+  public ReadOperation withQueryName(String queryName) {
+    return toBuilder().setQueryName(queryName).build();
+  }
+
   public ReadOperation withKeySet(KeySet keySet) {
     return toBuilder().setKeySet(keySet).build();
   }
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index 07ff216..1244b90 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -43,9 +43,13 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.OptionalInt;
 import java.util.concurrent.TimeUnit;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.ServiceCallMetric;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 import org.apache.beam.sdk.coders.SerializableCoder;
@@ -661,6 +665,10 @@ public class SpannerIO {
       return withQuery(Statement.of(sql));
     }
 
+    public Read withQueryName(String queryName) {
+      return withReadOperation(getReadOperation().withQueryName(queryName));
+    }
+
     public Read withKeySet(KeySet keySet) {
       return withReadOperation(getReadOperation().withKeySet(keySet));
     }
@@ -1638,10 +1646,18 @@ public class SpannerIO {
     private void spannerWriteWithRetryIfSchemaChange(Iterable<Mutation> batch)
         throws SpannerException {
       for (int retry = 1; ; retry++) {
+        ServiceCallMetric serviceCallMetric =
+            createServiceCallMetric(
+                this.spannerConfig.getProjectId().toString(),
+                this.spannerConfig.getDatabaseId().toString(),
+                this.spannerConfig.getInstanceId().toString(),
+                "Write");
         try {
           spannerAccessor.getDatabaseClient().writeAtLeastOnce(batch);
+          serviceCallMetric.call("ok");
           return;
         } catch (AbortedException e) {
+          serviceCallMetric.call(e.getErrorCode().getGrpcStatusCode().toString());
           if (retry >= ABORTED_RETRY_ATTEMPTS) {
             throw e;
           }
@@ -1649,10 +1665,30 @@ public class SpannerIO {
             continue;
           }
           throw e;
+        } catch (SpannerException e) {
+          serviceCallMetric.call(e.getErrorCode().getGrpcStatusCode().toString());
+          throw e;
         }
       }
     }
 
+    private ServiceCallMetric createServiceCallMetric(
+        String projectId, String databaseId, String tableId, String method) {
+      HashMap<String, String> baseLabels = new HashMap<>();
+      baseLabels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+      baseLabels.put(MonitoringInfoConstants.Labels.SERVICE, "Spanner");
+      baseLabels.put(MonitoringInfoConstants.Labels.METHOD, method);
+      baseLabels.put(
+          MonitoringInfoConstants.Labels.RESOURCE,
+          GcpResourceIdentifiers.spannerTable(projectId, databaseId, tableId));
+      baseLabels.put(MonitoringInfoConstants.Labels.SPANNER_PROJECT_ID, projectId);
+      baseLabels.put(MonitoringInfoConstants.Labels.SPANNER_DATABASE_ID, databaseId);
+      baseLabels.put(MonitoringInfoConstants.Labels.SPANNER_INSTANCE_ID, tableId);
+      ServiceCallMetric serviceCallMetric =
+          new ServiceCallMetric(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, baseLabels);
+      return serviceCallMetric;
+    }
+
     /** Write the Mutations to Spanner, handling DEADLINE_EXCEEDED with backoff/retries. */
     private void writeMutations(Iterable<Mutation> mutations) throws SpannerException, IOException {
       BackOff backoff = bundleWriteBackoff.backoff();
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
index 5977c2e..0610596 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOReadTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.gcp.spanner;
 
+import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.when;
@@ -24,12 +25,14 @@ import static org.mockito.Mockito.when;
 import com.google.cloud.Timestamp;
 import com.google.cloud.spanner.BatchReadOnlyTransaction;
 import com.google.cloud.spanner.BatchTransactionId;
+import com.google.cloud.spanner.ErrorCode;
 import com.google.cloud.spanner.FakeBatchTransactionId;
 import com.google.cloud.spanner.FakePartitionFactory;
 import com.google.cloud.spanner.KeySet;
 import com.google.cloud.spanner.Partition;
 import com.google.cloud.spanner.PartitionOptions;
 import com.google.cloud.spanner.ResultSets;
+import com.google.cloud.spanner.SpannerExceptionFactory;
 import com.google.cloud.spanner.Statement;
 import com.google.cloud.spanner.Struct;
 import com.google.cloud.spanner.TimestampBound;
@@ -38,12 +41,19 @@ import com.google.cloud.spanner.Value;
 import com.google.protobuf.ByteString;
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -83,6 +93,9 @@ public class SpannerIOReadTest implements Serializable {
   public void setUp() throws Exception {
     serviceFactory = new FakeServiceFactory();
     mockBatchTx = Mockito.mock(BatchReadOnlyTransaction.class);
+    // Setup the ProcessWideContainer for testing metrics are set.
+    MetricsContainerImpl container = new MetricsContainerImpl(null);
+    MetricsEnvironment.setProcessWideContainer(container);
   }
 
   @Test
@@ -90,12 +103,7 @@ public class SpannerIOReadTest implements Serializable {
     Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
     TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
 
-    SpannerConfig spannerConfig =
-        SpannerConfig.create()
-            .withProjectId("test")
-            .withInstanceId("123")
-            .withDatabaseId("aaa")
-            .withServiceFactory(serviceFactory);
+    SpannerConfig spannerConfig = getSpannerConfig();
 
     PCollection<Struct> one =
         pipeline.apply(
@@ -129,17 +137,20 @@ public class SpannerIOReadTest implements Serializable {
     pipeline.run();
   }
 
+  private SpannerConfig getSpannerConfig() {
+    return SpannerConfig.create()
+        .withProjectId("test")
+        .withInstanceId("123")
+        .withDatabaseId("aaa")
+        .withServiceFactory(serviceFactory);
+  }
+
   @Test
   public void runRead() throws Exception {
     Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
     TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
 
-    SpannerConfig spannerConfig =
-        SpannerConfig.create()
-            .withProjectId("test")
-            .withInstanceId("123")
-            .withDatabaseId("aaa")
-            .withServiceFactory(serviceFactory);
+    SpannerConfig spannerConfig = getSpannerConfig();
 
     PCollection<Struct> one =
         pipeline.apply(
@@ -179,16 +190,137 @@ public class SpannerIOReadTest implements Serializable {
   }
 
   @Test
+  public void testQueryMetrics() throws Exception {
+    Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
+    TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
+
+    SpannerConfig spannerConfig = getSpannerConfig();
+
+    PCollection<Struct> one =
+        pipeline.apply(
+            "read q",
+            SpannerIO.read()
+                .withSpannerConfig(spannerConfig)
+                .withQuery("SELECT * FROM users")
+                .withQueryName("queryName")
+                .withTimestampBound(timestampBound));
+
+    FakeBatchTransactionId id = new FakeBatchTransactionId("runQueryTest");
+    when(mockBatchTx.getBatchTransactionId()).thenReturn(id);
+
+    when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(timestampBound))
+        .thenReturn(mockBatchTx);
+    when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(any(BatchTransactionId.class)))
+        .thenReturn(mockBatchTx);
+
+    Partition fakePartition =
+        FakePartitionFactory.createFakeQueryPartition(ByteString.copyFromUtf8("one"));
+
+    when(mockBatchTx.partitionQuery(
+            any(PartitionOptions.class), eq(Statement.of("SELECT * FROM users"))))
+        .thenReturn(Arrays.asList(fakePartition, fakePartition));
+    when(mockBatchTx.execute(any(Partition.class)))
+        .thenThrow(
+            SpannerExceptionFactory.newSpannerException(
+                ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1"))
+        .thenThrow(
+            SpannerExceptionFactory.newSpannerException(
+                ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 2"))
+        .thenReturn(
+            ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(0, 2)),
+            ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 6)));
+
+    pipeline.run();
+    verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 2);
+    verifyMetricWasSet("test", "aaa", "123", "ok", null, 2);
+  }
+
+  @Test
+  public void testReadMetrics() throws Exception {
+    Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
+    TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
+
+    SpannerConfig spannerConfig = getSpannerConfig();
+
+    PCollection<Struct> one =
+        pipeline.apply(
+            "read q",
+            SpannerIO.read()
+                .withSpannerConfig(spannerConfig)
+                .withTable("users")
+                .withColumns("id", "name")
+                .withTimestampBound(timestampBound));
+
+    FakeBatchTransactionId id = new FakeBatchTransactionId("runReadTest");
+    when(mockBatchTx.getBatchTransactionId()).thenReturn(id);
+
+    when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(timestampBound))
+        .thenReturn(mockBatchTx);
+    when(serviceFactory.mockBatchClient().batchReadOnlyTransaction(any(BatchTransactionId.class)))
+        .thenReturn(mockBatchTx);
+
+    Partition fakePartition =
+        FakePartitionFactory.createFakeReadPartition(ByteString.copyFromUtf8("one"));
+
+    when(mockBatchTx.partitionRead(
+            any(PartitionOptions.class),
+            eq("users"),
+            eq(KeySet.all()),
+            eq(Arrays.asList("id", "name"))))
+        .thenReturn(Arrays.asList(fakePartition, fakePartition, fakePartition));
+    when(mockBatchTx.execute(any(Partition.class)))
+        .thenThrow(
+            SpannerExceptionFactory.newSpannerException(
+                ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1"))
+        .thenThrow(
+            SpannerExceptionFactory.newSpannerException(
+                ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 2"))
+        .thenReturn(
+            ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(0, 2)),
+            ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(2, 4)),
+            ResultSets.forRows(FAKE_TYPE, FAKE_ROWS.subList(4, 6)));
+
+    pipeline.run();
+    verifyMetricWasSet("test", "aaa", "123", "deadline_exceeded", null, 2);
+    verifyMetricWasSet("test", "aaa", "123", "ok", null, 3);
+  }
+
+  private void verifyMetricWasSet(
+      String projectId,
+      String databaseId,
+      String tableId,
+      String status,
+      @Nullable String queryName,
+      long count) {
+    // Verify the metric was reported.
+    HashMap<String, String> labels = new HashMap<>();
+    labels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+    labels.put(MonitoringInfoConstants.Labels.SERVICE, "Spanner");
+    labels.put(MonitoringInfoConstants.Labels.METHOD, "Read");
+    labels.put(
+        MonitoringInfoConstants.Labels.RESOURCE,
+        GcpResourceIdentifiers.spannerTable(projectId, databaseId, tableId));
+    labels.put(MonitoringInfoConstants.Labels.SPANNER_PROJECT_ID, projectId);
+    labels.put(MonitoringInfoConstants.Labels.SPANNER_DATABASE_ID, databaseId);
+    labels.put(MonitoringInfoConstants.Labels.SPANNER_INSTANCE_ID, tableId);
+    if (queryName != null) {
+      labels.put(MonitoringInfoConstants.Labels.SPANNER_QUERY_NAME, queryName);
+    }
+    labels.put(MonitoringInfoConstants.Labels.STATUS, status);
+
+    MonitoringInfoMetricName name =
+        MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
+    MetricsContainerImpl container =
+        (MetricsContainerImpl) MetricsEnvironment.getProcessWideContainer();
+    assertEquals(count, (long) container.getCounter(name).getCumulative());
+  }
+
+  @Test
   public void runReadUsingIndex() throws Exception {
     Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
     TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
 
-    SpannerConfig spannerConfig =
-        SpannerConfig.create()
-            .withProjectId("test")
-            .withInstanceId("123")
-            .withDatabaseId("aaa")
-            .withServiceFactory(serviceFactory);
+    SpannerConfig spannerConfig = getSpannerConfig();
 
     PCollection<Struct> one =
         pipeline.apply(
@@ -237,12 +369,7 @@ public class SpannerIOReadTest implements Serializable {
     Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
     TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
 
-    SpannerConfig spannerConfig =
-        SpannerConfig.create()
-            .withProjectId("test")
-            .withInstanceId("123")
-            .withDatabaseId("aaa")
-            .withServiceFactory(serviceFactory);
+    SpannerConfig spannerConfig = getSpannerConfig();
 
     PCollection<Struct> one =
         pipeline.apply(
@@ -281,12 +408,7 @@ public class SpannerIOReadTest implements Serializable {
     Timestamp timestamp = Timestamp.ofTimeMicroseconds(12345);
     TimestampBound timestampBound = TimestampBound.ofReadTimestamp(timestamp);
 
-    SpannerConfig spannerConfig =
-        SpannerConfig.create()
-            .withProjectId("test")
-            .withInstanceId("123")
-            .withDatabaseId("aaa")
-            .withServiceFactory(serviceFactory);
+    SpannerConfig spannerConfig = getSpannerConfig();
 
     PCollectionView<Transaction> tx =
         pipeline.apply(
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
index 58ce514..d0239fc 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -50,13 +50,19 @@ import com.google.cloud.spanner.Struct;
 import com.google.cloud.spanner.Type;
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
+import org.apache.beam.runners.core.metrics.GcpResourceIdentifiers;
+import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
+import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
+import org.apache.beam.runners.core.metrics.MonitoringInfoMetricName;
 import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.BatchableMutationFilterFn;
 import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.FailureMode;
 import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.GatherSortCreateBatchesFn;
 import org.apache.beam.sdk.io.gcp.spanner.SpannerIO.WriteToSpannerFn;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -120,6 +126,10 @@ public class SpannerIOWriteTest implements Serializable {
     // Simplest schema: a table with int64 key
     preparePkMetadata(tx, Arrays.asList(pkMetadata("tEsT", "key", "ASC")));
     prepareColumnMetadata(tx, Arrays.asList(columnMetadata("tEsT", "key", "INT64", CELLS_PER_KEY)));
+
+    // Setup the ProcessWideContainer for testing metrics are set.
+    MetricsContainerImpl container = new MetricsContainerImpl(null);
+    MetricsEnvironment.setProcessWideContainer(container);
   }
 
   private SpannerSchema getSchema() {
@@ -408,6 +418,62 @@ public class SpannerIOWriteTest implements Serializable {
   }
 
   @Test
+  public void testSpannerWriteMetricIsSet() {
+    Mutation mutation = m(2L);
+    PCollection<Mutation> mutations = pipeline.apply(Create.of(mutation));
+
+    // respond with 2 error codes and a success.
+    when(serviceFactory.mockDatabaseClient().writeAtLeastOnce(any()))
+        .thenThrow(
+            SpannerExceptionFactory.newSpannerException(
+                ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 1"))
+        .thenThrow(
+            SpannerExceptionFactory.newSpannerException(
+                ErrorCode.DEADLINE_EXCEEDED, "Simulated Timeout 2"))
+        .thenReturn(Timestamp.now());
+
+    mutations.apply(
+        SpannerIO.write()
+            .withProjectId("test-project")
+            .withInstanceId("test-instance")
+            .withDatabaseId("test-database")
+            .withFailureMode(FailureMode.FAIL_FAST)
+            .withServiceFactory(serviceFactory));
+    pipeline.run();
+
+    verifyMetricWasSet(
+        "test-project", "test-database", "test-instance", "Write", "deadline_exceeded", 2);
+    verifyMetricWasSet("test-project", "test-database", "test-instance", "Write", "ok", 1);
+  }
+
+  private void verifyMetricWasSet(
+      String projectId,
+      String databaseId,
+      String tableId,
+      String method,
+      String status,
+      long count) {
+    // Verify the metric was reported.
+    HashMap<String, String> labels = new HashMap<>();
+    labels.put(MonitoringInfoConstants.Labels.PTRANSFORM, "");
+    labels.put(MonitoringInfoConstants.Labels.SERVICE, "Spanner");
+    labels.put(MonitoringInfoConstants.Labels.METHOD, method);
+    labels.put(
+        MonitoringInfoConstants.Labels.RESOURCE,
+        GcpResourceIdentifiers.spannerTable(projectId, databaseId, tableId));
+    labels.put(MonitoringInfoConstants.Labels.SPANNER_PROJECT_ID, projectId);
+    labels.put(MonitoringInfoConstants.Labels.SPANNER_DATABASE_ID, databaseId);
+    labels.put(MonitoringInfoConstants.Labels.SPANNER_INSTANCE_ID, tableId);
+    labels.put(MonitoringInfoConstants.Labels.STATUS, status);
+
+    MonitoringInfoMetricName name =
+        MonitoringInfoMetricName.named(MonitoringInfoConstants.Urns.API_REQUEST_COUNT, labels);
+    MetricsContainerImpl container =
+        (MetricsContainerImpl) MetricsEnvironment.getProcessWideContainer();
+    assertEquals(count, (long) container.getCounter(name).getCumulative());
+  }
+
+  @Test
   public void deadlineExceededRetries() throws InterruptedException {
     List<Mutation> mutationList = Arrays.asList(m((long) 1));
 

Reply via email to