This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a91d139dea Golang SpannerIO Implementation (#23285)
0a91d139dea is described below

commit 0a91d139dea4276dc46176c4cdcdfce210fc50c4
Author: Chris Gavin <[email protected]>
AuthorDate: Wed Nov 23 09:56:40 2022 +1100

    Golang SpannerIO Implementation (#23285)
---
 sdks/go.mod                                   |   4 +-
 sdks/go.sum                                   |   6 +
 sdks/go/pkg/beam/io/spannerio/spanner.go      | 241 +++++++++++++++++++++++++
 sdks/go/pkg/beam/io/spannerio/spanner_test.go | 249 ++++++++++++++++++++++++++
 4 files changed, 499 insertions(+), 1 deletion(-)

diff --git a/sdks/go.mod b/sdks/go.mod
index 061250fc0dd..b9d0728c711 100644
--- a/sdks/go.mod
+++ b/sdks/go.mod
@@ -61,6 +61,8 @@ require (
        gopkg.in/yaml.v2 v2.4.0
 )
 
+require cloud.google.com/go/spanner v1.36.0
+
 require (
        cloud.google.com/go/bigtable v1.18.0
        github.com/tetratelabs/wazero v1.0.0-pre.3
@@ -92,7 +94,7 @@ require (
        github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.8 // indirect
        github.com/aws/aws-sdk-go-v2/service/sts v1.17.4 // indirect
        github.com/cenkalti/backoff/v4 v4.1.3 // indirect
-       github.com/census-instrumentation/opencensus-proto v0.2.1 // indirect
+       github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect
        github.com/cespare/xxhash/v2 v2.1.2 // indirect
        github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4 // indirect
        github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1 // indirect
diff --git a/sdks/go.sum b/sdks/go.sum
index acee0826903..2c1d96beab1 100644
--- a/sdks/go.sum
+++ b/sdks/go.sum
@@ -68,6 +68,8 @@ cloud.google.com/go/pubsub v1.2.0/go.mod 
h1:jhfEVHT8odbXTkndysNHCcx0awwzvfOlguIA
 cloud.google.com/go/pubsub v1.3.1/go.mod 
h1:i+ucay31+CNRpDW4Lu78I4xXG+O1r/MAHgjpRVR+TSU=
 cloud.google.com/go/pubsub v1.26.0 
h1:Y/HcMxVXgkUV2pYeLMUkclMg0ue6U0jVyI5xEARQ4zA=
 cloud.google.com/go/pubsub v1.26.0/go.mod 
h1:QgBH3U/jdJy/ftjPhTkyXNj543Tin1pRYcdcPRnFIRI=
+cloud.google.com/go/spanner v1.36.0 
h1:MYc3fKJlZZCpZymoKBqPR23Hxd1CFhH+zsQPMzeM1xI=
+cloud.google.com/go/spanner v1.36.0/go.mod 
h1:RKVKnqXxTMDuBPAsjxohvcSTH6qiRB6E0oMljFIKPr0=
 cloud.google.com/go/storage v1.0.0/go.mod 
h1:IhtSnM/ZTZV8YYJWCY8RULGVqBDmpoyjwiyrjsg+URw=
 cloud.google.com/go/storage v1.5.0/go.mod 
h1:tpKbwo567HUNpVclU5sGELwQWBDZ8gh0ZeosJ0Rtdos=
 cloud.google.com/go/storage v1.6.0/go.mod 
h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohlUTyfDhBk=
@@ -214,6 +216,10 @@ github.com/cenkalti/backoff/v4 v4.1.3 
h1:cFAlzYUlVYDysBEH2T5hyJZMh3+5+WCBvSnK6Q8
 github.com/cenkalti/backoff/v4 v4.1.3/go.mod 
h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
 github.com/census-instrumentation/opencensus-proto v0.2.1 
h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod 
h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/census-instrumentation/opencensus-proto v0.3.0 
h1:t/LhUZLVitR1Ow2YOnduCsavhwFUklBMoGVYUCqmCqk=
+github.com/census-instrumentation/opencensus-proto v0.3.0/go.mod 
h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod 
h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA=
+github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod 
h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA=
 github.com/cespare/xxhash v1.1.0/go.mod 
h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
 github.com/cespare/xxhash/v2 v2.1.1/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cespare/xxhash/v2 v2.1.2 
h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE=
diff --git a/sdks/go/pkg/beam/io/spannerio/spanner.go 
b/sdks/go/pkg/beam/io/spannerio/spanner.go
new file mode 100644
index 00000000000..4ad9158a603
--- /dev/null
+++ b/sdks/go/pkg/beam/io/spannerio/spanner.go
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package spannerio provides an API for reading and writing resources to
+// Google Cloud Spanner databases.
+package spannerio
+
+import (
+       "context"
+       "fmt"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+       "reflect"
+       "strings"
+
+       "cloud.google.com/go/spanner"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/register"
+       "google.golang.org/api/iterator"
+)
+
+func init() {
+       register.DoFn3x1[context.Context, []byte, func(beam.X), 
error]((*queryFn)(nil))
+       register.Emitter1[beam.X]()
+       register.DoFn3x1[context.Context, int, func(*beam.X) bool, 
error]((*writeFn)(nil))
+       register.Iter1[beam.X]()
+}
+
+func columnsFromStruct(t reflect.Type) []string {
+       var columns []string
+
+       for i := 0; i < t.NumField(); i++ {
+               columns = append(columns, t.Field(i).Tag.Get("spanner"))
+       }
+
+       return columns
+}
+
+// Read reads all rows from the given table. The table must have a schema
+// compatible with the given type, t, and Read returns a PCollection<t>. If the
+// table has more rows than t, then Read is implicitly a projection.
+func Read(s beam.Scope, database string, table string, t reflect.Type) 
beam.PCollection {
+       s = s.Scope("spanner.Read")
+
+       // TODO(herohde) 7/13/2017: using * is probably too inefficient. We 
could infer
+       // a focused query from the type.
+
+       cols := strings.Join(columnsFromStruct(t), ",")
+
+       return query(s, database, nil, fmt.Sprintf("SELECT %v from %v", cols, 
table), t)
+}
+
+// queryOptions represents additional options for executing a query.
+type queryOptions struct {
+}
+
+// Query executes a query. The output must have a schema compatible with the 
given
+// type, t. It returns a PCollection<t>.
+// Note: Query will be executed on a single worker. Consider performance of 
query
+// and if downstream splitting is required add beam.Reshuffle.
+func Query(s beam.Scope, database string, q string, t reflect.Type, options 
...func(*queryOptions) error) beam.PCollection {
+       s = s.Scope("spanner.Query")
+       return query(s, database, nil, q, t, options...)
+}
+
+// Entrypoint for testing with spanner client injectable.
+func query(s beam.Scope, database string, client *spanner.Client, query 
string, t reflect.Type, options ...func(*queryOptions) error) beam.PCollection {
+       queryOptions := queryOptions{}
+       for _, opt := range options {
+               if err := opt(&queryOptions); err != nil {
+                       panic(err)
+               }
+       }
+
+       imp := beam.Impulse(s)
+       return beam.ParDo(s, &queryFn{Database: database, Query: query, Type: 
beam.EncodedType{T: t}, Options: queryOptions, Client: client}, imp, 
beam.TypeDefinition{Var: beam.XType, T: t})
+}
+
+type queryFn struct {
+       Database string           `json:"database"` // Database is the spanner 
connection string
+       Query    string           `json:"query"`    // Table is the table 
identifier.
+       Type     beam.EncodedType `json:"type"`     // Type is the encoded 
schema type.
+       Options  queryOptions     `json:"options"`  // Options specifies 
additional query execution options.
+
+       Client *spanner.Client // Spanner Client
+}
+
+func (f *queryFn) Setup() {
+       if f.Client == nil {
+               client, err := spanner.NewClient(context.Background(), 
f.Database)
+               if err != nil {
+                       panic("Failed to initialise Spanner client: " + 
err.Error())
+               }
+
+               f.Client = client
+       }
+}
+
+func (f *queryFn) Teardown() {
+       if f.Client != nil {
+               f.Client.Close()
+       }
+}
+
+func (f *queryFn) ProcessElement(ctx context.Context, _ []byte, emit 
func(beam.X)) error {
+       // todo: Use Batch Read
+
+       stmt := spanner.Statement{SQL: f.Query}
+       it := f.Client.Single().Query(ctx, stmt)
+       defer it.Stop()
+
+       for {
+               val := reflect.New(f.Type.T).Interface() // val : *T
+               row, err := it.Next()
+               if err != nil {
+                       if err == iterator.Done {
+                               break
+                       }
+                       return err
+               }
+
+               if err := row.ToStruct(val); err != nil {
+                       return err
+               }
+
+               emit(reflect.ValueOf(val).Elem().Interface()) // emit(*val)
+       }
+       return nil
+}
+
// writeOptions holds the optional settings for Write.
type writeOptions struct {
	// BatchSize is the maximum number of mutations sent to Spanner in a single
	// Apply call.
	BatchSize int
}

// UseBatchSize sets the maximum number of mutations Write sends to Spanner in a
// single Apply call. batchSize must be at least 1. Otherwise the returned
// option reports an error and leaves the options unchanged, which makes Write
// panic while the pipeline is being constructed.
func UseBatchSize(batchSize int) func(qo *writeOptions) error {
	return func(qo *writeOptions) error {
		if batchSize < 1 {
			return fmt.Errorf("spannerio: batch size must be at least 1, got %d", batchSize)
		}
		qo.BatchSize = batchSize
		return nil
	}
}
+
+// Write writes the elements of the given PCollection<T> to spanner. T is 
required
+// to be the schema type.
+// Note: Writes occur against a single worker machine.
+func Write(s beam.Scope, database string, table string, col beam.PCollection, 
options ...func(*writeOptions) error) {
+       if typex.IsCoGBK(col.Type()) || typex.IsKV(col.Type()) {
+               panic("Unsupported collection type - only normal structs 
supported for writing.")
+       }
+
+       s = s.Scope("spanner.Write")
+
+       write(s, database, nil, table, col, options...)
+}
+
+// Entrypoint for testing with spanner client injectable.
+func write(s beam.Scope, database string, client *spanner.Client, table 
string, col beam.PCollection, options ...func(*writeOptions) error) {
+       writeOptions := writeOptions{
+               BatchSize: 1000, // default
+       }
+       for _, opt := range options {
+               if err := opt(&writeOptions); err != nil {
+                       panic(err)
+               }
+       }
+
+       t := col.Type().Type()
+
+       pre := beam.AddFixedKey(s, col)
+       post := beam.GroupByKey(s, pre)
+       beam.ParDo0(s, &writeFn{Database: database, Table: table, Type: 
beam.EncodedType{T: t}, Options: writeOptions, Client: client}, post)
+}
+
+type writeFn struct {
+       Database string           `json:"database"` // Fully qualified 
identifier
+       Table    string           `json:"table"`    // The table to write to
+       Type     beam.EncodedType `json:"type"`     // Type is the encoded 
schema type.
+       Options  writeOptions     `json:"options"`  // Spanner write options
+
+       Client *spanner.Client // Spanner Client
+}
+
+func (f *writeFn) Setup() {
+       if f.Client == nil {
+               client, err := spanner.NewClient(context.Background(), 
f.Database)
+               if err != nil {
+                       panic("Failed to initialise Spanner client: " + 
err.Error())
+               }
+
+               f.Client = client
+       }
+}
+
+func (f *writeFn) Teardown() {
+       if f.Client != nil {
+               f.Client.Close()
+       }
+}
+
+func (f *writeFn) ProcessElement(ctx context.Context, _ int, iter 
func(*beam.X) bool) error {
+       var mutations []*spanner.Mutation
+
+       var val beam.X
+       for iter(&val) {
+               mutation, err := spanner.InsertOrUpdateStruct(f.Table, val)
+               if err != nil {
+                       return err
+               }
+
+               mutations = append(mutations, mutation)
+
+               if len(mutations)+1 > f.Options.BatchSize {
+                       _, err := f.Client.Apply(ctx, mutations)
+                       if err != nil {
+                               return err
+                       }
+
+                       mutations = nil
+               }
+       }
+
+       if mutations != nil {
+               _, err := f.Client.Apply(ctx, mutations)
+               if err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
diff --git a/sdks/go/pkg/beam/io/spannerio/spanner_test.go 
b/sdks/go/pkg/beam/io/spannerio/spanner_test.go
new file mode 100644
index 00000000000..f83c4d55f9e
--- /dev/null
+++ b/sdks/go/pkg/beam/io/spannerio/spanner_test.go
@@ -0,0 +1,249 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package spannerio
+
+import (
+       "cloud.google.com/go/spanner"
+       "cloud.google.com/go/spanner/spannertest"
+       "cloud.google.com/go/spanner/spansql"
+       "context"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
+       "google.golang.org/api/iterator"
+       "google.golang.org/api/option"
+       "google.golang.org/grpc"
+       "reflect"
+       "testing"
+)
+
+type TestDto struct {
+       One string `spanner:"One"`
+       Two int64  `spanner:"Two"`
+}
+
+func TestColumnsFromStructReturnsColumns(t *testing.T) {
+       // arrange
+       // act
+       cols := columnsFromStruct(reflect.TypeOf(TestDto{}))
+
+       // assert
+       if len(cols) != 2 {
+               t.Fatalf("got %v columns, expected 2", len(cols))
+       }
+}
+
+func TestRead(t *testing.T) {
+       testCases := []struct {
+               name          string
+               database      string
+               rows          []TestDto
+               expectedError bool
+       }{
+               {
+                       name:     "Successfully read 4 rows",
+                       database: 
"projects/fake-proj/instances/fake-instance/databases/fake-db-4-rows",
+                       rows: []TestDto{
+                               {
+                                       One: "one",
+                                       Two: 1,
+                               },
+                               {
+                                       One: "one",
+                                       Two: 2,
+                               },
+                               {
+                                       One: "one",
+                                       Two: 3,
+                               },
+                               {
+                                       One: "one",
+                                       Two: 4,
+                               },
+                       },
+               },
+               {
+                       name:     "Successfully read 1 rows",
+                       database: 
"projects/fake-proj/instances/fake-instance/databases/fake-db-1-rows",
+                       rows: []TestDto{
+                               {
+                                       One: "one",
+                                       Two: 1,
+                               },
+                       },
+               },
+       }
+
+       for _, testCase := range testCases {
+               t.Run(testCase.name, func(t *testing.T) {
+                       srv, srvCleanup := newServer(t)
+                       defer srvCleanup()
+
+                       client, cleanup := createFakeClient(t, srv.Addr, 
testCase.database)
+                       defer cleanup()
+
+                       ddl, err := spansql.ParseDDL("",
+                               `CREATE TABLE Test (
+                                       One STRING(20),
+                                       Two INT64,
+                               ) PRIMARY KEY (Two)`)
+                       if err != nil {
+                               t.Fatalf("Unable to create DDL statement for 
spanner test: %v", err)
+                       }
+
+                       err = srv.UpdateDDL(ddl)
+                       if err != nil {
+                               t.Fatalf("Unable to run DDL into spanner db: 
%v", err)
+                       }
+
+                       var mutations []*spanner.Mutation
+                       for _, m := range testCase.rows {
+                               mutation, err := spanner.InsertStruct("Test", m)
+                               if err != nil {
+                                       t.Fatalf("Unable to create mutation to 
insert struct: %v", err)
+                               }
+
+                               mutations = append(mutations, mutation)
+                       }
+
+                       _, err = client.Apply(context.Background(), mutations)
+                       if err != nil {
+                               t.Fatalf("Applying mutations: %v", err)
+                       }
+
+                       p := beam.NewPipeline()
+                       s := p.Root()
+                       rows := query(s, "", client, "SELECT * from Test", 
reflect.TypeOf(TestDto{}))
+
+                       passert.Count(s, rows, "", len(testCase.rows))
+                       ptest.RunAndValidate(t, p)
+               })
+       }
+}
+
+func TestWrite(t *testing.T) {
+       testCases := []struct {
+               name          string
+               database      string
+               rows          []TestDto
+               expectedError bool
+       }{
+               {
+                       name:     "Successfully write 4 rows",
+                       database: 
"projects/fake-proj/instances/fake-instance/databases/fake-db-4-rows",
+                       rows: []TestDto{
+                               {
+                                       One: "one",
+                                       Two: 1,
+                               },
+                               {
+                                       One: "one",
+                                       Two: 2,
+                               },
+                               {
+                                       One: "one",
+                                       Two: 3,
+                               },
+                               {
+                                       One: "one",
+                                       Two: 4,
+                               },
+                       },
+               },
+       }
+
+       for _, testCase := range testCases {
+               t.Run(testCase.name, func(t *testing.T) {
+                       srv, srvCleanup := newServer(t)
+                       defer srvCleanup()
+
+                       client, cleanup := createFakeClient(t, srv.Addr, 
testCase.database)
+                       defer cleanup()
+
+                       ddl, err := spansql.ParseDDL("",
+                               `CREATE TABLE Test (
+                                       One STRING(20),
+                                       Two INT64,
+                               ) PRIMARY KEY (Two)`)
+                       if err != nil {
+                               t.Fatalf("Unable to create DDL statement for 
spanner test: %v", err)
+                       }
+
+                       err = srv.UpdateDDL(ddl)
+                       if err != nil {
+                               t.Fatalf("Unable to run DDL into spanner db: 
%v", err)
+                       }
+
+                       p, s, col := ptest.CreateList(testCase.rows)
+
+                       write(s, "", client, "Test", col)
+
+                       ptest.RunAndValidate(t, p)
+
+                       verifyClient, verifyClientCleanup := 
createFakeClient(t, srv.Addr, testCase.database)
+                       defer verifyClientCleanup()
+
+                       stmt := spanner.Statement{SQL: "SELECT * FROM Test"}
+                       it := verifyClient.Single().Query(context.Background(), 
stmt)
+                       defer it.Stop()
+
+                       var count int
+                       for {
+                               _, err := it.Next()
+                               if err != nil {
+                                       if err == iterator.Done {
+                                               break
+                                       }
+                               }
+                               count++
+                       }
+
+                       if count != len(testCase.rows) {
+                               t.Fatalf("Got incorrect number of rows from 
spanner write, got '%v', expected '%v'", count, len(testCase.rows))
+                       }
+               })
+       }
+}
+
+func newServer(t *testing.T) (*spannertest.Server, func()) {
+       srv, err := spannertest.NewServer("localhost:0")
+       if err != nil {
+               t.Fatalf("Starting in-memory fake spanner: %v", err)
+       }
+
+       return srv, func() {
+               srv.Close()
+       }
+}
+
+func createFakeClient(t *testing.T, address string, database string) 
(*spanner.Client, func()) {
+       ctx := context.Background()
+
+       conn, err := grpc.DialContext(ctx, address, grpc.WithInsecure())
+       if err != nil {
+               t.Fatalf("Dialling in-memory fake spanner: %v", err)
+       }
+
+       client, err := spanner.NewClient(ctx, database, 
option.WithGRPCConn(conn))
+       if err != nil {
+               t.Fatalf("Connecting to in-memory fake spanner: %v", err)
+       }
+
+       return client, func() {
+               client.Close()
+               conn.Close()
+       }
+}

Reply via email to