[ 
https://issues.apache.org/jira/browse/BEAM-5729?focusedWorklogId=166218&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-166218
 ]

ASF GitHub Bot logged work on BEAM-5729:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 15/Nov/18 00:25
            Start Date: 15/Nov/18 00:25
    Worklog Time Spent: 10m 
      Work Description: chamikaramj closed pull request #6676: [BEAM-5729] 
added database/sql reader/writer
URL: https://github.com/apache/beam/pull/6676
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/sdks/go/pkg/beam/io/databaseio/database.go 
b/sdks/go/pkg/beam/io/databaseio/database.go
new file mode 100644
index 00000000000..e9de9cbbbe9
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/database.go
@@ -0,0 +1,219 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package databaseio provides transformations and utilities to interact with
// generic databases through the database/sql API. See also:
// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+       "context"
+       "database/sql"
+       "fmt"
+       "github.com/apache/beam/sdks/go/pkg/beam"
+       "github.com/apache/beam/sdks/go/pkg/beam/log"
+       "reflect"
+       "strings"
+)
+
+func init() {
+       beam.RegisterType(reflect.TypeOf((*queryFn)(nil)).Elem())
+       beam.RegisterType(reflect.TypeOf((*writeFn)(nil)).Elem())
+}
+
// writeRowLimit is the default batch size used by Write: the maximum number of
// rows placed into a single multi-row INSERT statement.
const writeRowLimit = 1000
+
+// Read reads all rows from the given table. The table must have a schema
+// compatible with the given type, t, and Read returns a PCollection<t>. If the
+// table has more rows than t, then Read is implicitly a projection.
+func Read(s beam.Scope, driver, dsn, table string, t reflect.Type) 
beam.PCollection {
+       s = s.Scope(driver + ".Read")
+       return query(s, driver, dsn, fmt.Sprintf("SELECT * from %v", table), t)
+}
+
+// Query executes a query. The output must have a schema compatible with the 
given
+// type, t. It returns a PCollection<t>.
+func Query(s beam.Scope, driver, dsn, q string, t reflect.Type) 
beam.PCollection {
+       s = s.Scope(driver + ".Query")
+       return query(s, driver, dsn, q, t)
+}
+
+func query(s beam.Scope, driver, dsn, query string, t reflect.Type) 
beam.PCollection {
+       imp := beam.Impulse(s)
+       return beam.ParDo(s, &queryFn{Driver: driver, Dsn: dsn, Query: query, 
Type: beam.EncodedType{T: t}}, imp, beam.TypeDefinition{Var: beam.XType, T: t})
+}
+
+type queryFn struct {
+       // Project is the project
+       Driver string `json:"driver"`
+       // Project is the project
+       Dsn string `json:"dsn"`
+       // Table is the table identifier.
+       Query string `json:"query"`
+       // Type is the encoded schema type.
+       Type beam.EncodedType `json:"type"`
+}
+
+func (f *queryFn) ProcessElement(ctx context.Context, _ []byte, emit 
func(beam.X)) error {
+       //TODO move DB Open and Close to Setup and Teardown methods or 
StartBundle and FinishBundle
+       db, err := sql.Open(f.Driver, f.Dsn)
+       if err != nil {
+               return fmt.Errorf("failed to open database: %v, %v", f.Driver, 
err)
+       }
+       defer db.Close()
+       statement, err := db.PrepareContext(ctx, f.Query)
+       if err != nil {
+               return fmt.Errorf("failed to prepare query: %v, %v", f.Query, 
err)
+       }
+       defer statement.Close()
+       rows, err := statement.QueryContext(ctx)
+       if err != nil {
+               return fmt.Errorf("failed to run query: %v, %v", f.Query, err)
+       }
+       defer rows.Close()
+       var mapper rowMapper
+       var columns []string
+       for rows.Next() {
+               reflectRow := reflect.New(f.Type.T)
+               row := reflectRow.Interface() // row : *T
+               if mapper == nil {
+                       columns, err = rows.Columns()
+                       if err != nil {
+                               return err
+                       }
+                       columnsTypes, _ := rows.ColumnTypes()
+                       if mapper, err = newQueryMapper(columns, columnsTypes, 
f.Type.T); err != nil {
+                               return fmt.Errorf("failed to create rowValues 
mapper: %v", err)
+                       }
+               }
+               rowValues, err := mapper(reflectRow)
+               if err != nil {
+                       return err
+               }
+               err = rows.Scan(rowValues...)
+               if err != nil {
+                       return fmt.Errorf("failed to scan %v, %v", f.Query, err)
+               }
+               if loader, ok := row.(MapLoader); ok {
+                       asDereferenceSlice(rowValues)
+                       loader.LoadMap(asMap(columns, rowValues))
+               } else if loader, ok := row.(SliceLoader); ok {
+                       asDereferenceSlice(rowValues)
+                       loader.LoadSlice(rowValues)
+               }
+               emit(reflect.ValueOf(row).Elem().Interface()) // emit(*row)
+       }
+       return nil
+}
+
+// Write writes the elements of the given PCollection<T> to database, if 
columns left empty all table columns are used to insert into, otherwise selected
+func Write(s beam.Scope, driver, dsn, table string, columns []string, col 
beam.PCollection) {
+       t := col.Type().Type()
+       s = s.Scope(driver + ".Write")
+       pre := beam.AddFixedKey(s, col)
+       post := beam.GroupByKey(s, pre)
+       beam.ParDo0(s, &writeFn{Driver: driver, Dsn: dsn, Table: table, 
Columns: columns, BatchSize: writeRowLimit, Type: beam.EncodedType{T: t}}, post)
+}
+
+// WriteWithBatchSize writes the elements of the given PCollection<T> to 
database with custom batch size. Batch size control number of elements in the 
batch INSERT statement.
+func WriteWithBatchSize(s beam.Scope, batchSize int, driver, dsn, table 
string, columns []string, col beam.PCollection) {
+       t := col.Type().Type()
+       s = s.Scope(driver + ".Write")
+       pre := beam.AddFixedKey(s, col)
+       post := beam.GroupByKey(s, pre)
+       beam.ParDo0(s, &writeFn{Driver: driver, Dsn: dsn, Table: table, 
Columns: columns, BatchSize: batchSize, Type: beam.EncodedType{T: t}}, post)
+}
+
+type writeFn struct {
+       // Project is the project
+       Driver string `json:"driver"`
+       // Project is the project
+       Dsn string `json:"dsn"`
+       // Table is the table identifier.
+       Table string `json:"table"`
+       // Columns to inserts, if empty then all columns
+       Columns []string `json:"columns"`
+       //BatchSize size
+       BatchSize int `json:"batchSize"`
+       // Type is the encoded schema type.
+       Type beam.EncodedType `json:"type"`
+}
+
+func (f *writeFn) ProcessElement(ctx context.Context, _ int, iter 
func(*beam.X) bool) error {
+       //TODO move DB Open and Close to Setup and Teardown methods or 
StartBundle and FinishBundle
+       db, err := sql.Open(f.Driver, f.Dsn)
+       if err != nil {
+               return fmt.Errorf("failed to open database: %v, %v", f.Driver, 
err)
+       }
+       defer db.Close()
+       projection := "*"
+       if len(f.Columns) > 0 {
+               projection = strings.Join(f.Columns, ",")
+       }
+       dql := fmt.Sprintf("SELECT %v FROM  %v WHERE 1 = 0", projection, 
f.Table)
+       query, err := db.Prepare(dql)
+       if err != nil {
+               return fmt.Errorf("failed to prepare query: %v, %v", f.Table, 
err)
+       }
+       defer query.Close()
+       rows, err := query.Query()
+       if err != nil {
+               return fmt.Errorf("failed to query: %v, %v", f.Table, err)
+       }
+       columns, err := rows.Columns()
+       if err != nil {
+               return fmt.Errorf("failed to discover column: %v, %v", f.Table, 
err)
+       }
+       //TODO move to Setup methods
+       mapper, err := newWriterRowMapper(columns, f.Type.T)
+       if err != nil {
+               return fmt.Errorf("failed to create row mapper: %v", err)
+       }
+       writer, err := newWriter(f.BatchSize, f.Table, columns)
+       if err != nil {
+               return err
+       }
+       var val beam.X
+       for iter(&val) {
+               var row []interface{}
+               var data map[string]interface{}
+               if writer, ok := val.(Writer); ok {
+                       if data, err = writer.SaveData(); err == nil {
+                               row = make([]interface{}, len(columns))
+                               for i, column := range columns {
+                                       row[i] = data[column]
+                               }
+                       }
+               } else {
+                       row, err = mapper(reflect.ValueOf(val))
+               }
+               if err != nil {
+                       return fmt.Errorf("failed to map row %T: %v", val, err)
+               }
+               if err = writer.add(row); err != nil {
+                       return err
+               }
+               if err := writer.writeBatchIfNeeded(ctx, db); err != nil {
+                       return err
+               }
+       }
+
+       if err := writer.writeIfNeeded(ctx, db); err != nil {
+               return err
+       }
+
+       log.Infof(ctx, "written %v row(s) into %v", writer.totalCount, f.Table)
+       return nil
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/loader.go 
b/sdks/go/pkg/beam/io/databaseio/loader.go
new file mode 100644
index 00000000000..9dabe17df92
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/loader.go
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package databaseio provides transformations and utilities to interact with
// generic databases through the database/sql API. See also:
// https://golang.org/pkg/database/sql/
+package databaseio
+
// MapLoader is implemented by query result types that populate themselves
// from a fetched row. LoadMap receives the row as a map from column name to
// the scanned column value and returns an error if the row cannot be loaded.
type MapLoader interface {
	LoadMap(row map[string]interface{}) error
}
+
// SliceLoader is implemented by query result types that populate themselves
// from a fetched row. LoadSlice receives the scanned column values in result
// column order and returns an error if the row cannot be loaded.
type SliceLoader interface {
	LoadSlice(row []interface{}) error
}
diff --git a/sdks/go/pkg/beam/io/databaseio/mapper.go 
b/sdks/go/pkg/beam/io/databaseio/mapper.go
new file mode 100644
index 00000000000..7c55f80b22b
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/mapper.go
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package databaseio provides transformations and utilities to interact with
// generic databases through the database/sql API. See also:
// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+       "database/sql"
+       "fmt"
+       "reflect"
+       "strings"
+       "time"
+)
+
+//rowMapper represents a record mapper
+type rowMapper func(value reflect.Value) ([]interface{}, error)
+
+//newQueryMapper creates a new record mapped
+func newQueryMapper(columns []string, columnTypes []*sql.ColumnType, 
recordType reflect.Type) (rowMapper, error) {
+       val := reflect.New(recordType).Interface()
+       if _, isLoader := val.(MapLoader); isLoader {
+               return newQueryLoaderMapper(columns, columnTypes)
+       } else if recordType.Kind() == reflect.Struct {
+               return newQueryStructMapper(columns, recordType)
+       }
+       return nil, fmt.Errorf("unsupported type %s", recordType)
+}
+
+//newQueryStructMapper creates a new record mapper for supplied struct type
+func newQueryStructMapper(columns []string, recordType reflect.Type) 
(rowMapper, error) {
+       mappedFieldIndex, err := mapFields(columns, recordType)
+       if err != nil {
+               return nil, err
+       }
+       var record = make([]interface{}, recordType.NumField())
+       var mapper = func(value reflect.Value) ([]interface{}, error) {
+               value = value.Elem() //T = *T
+               for i, fieldIndex := range mappedFieldIndex {
+                       record[i] = value.Field(fieldIndex).Addr().Interface()
+               }
+               return record, nil
+       }
+       return mapper, nil
+}
+
+//newQueryStructMapper creates a new record mapper for supplied struct type
+func newQueryLoaderMapper(columns []string, columnTypes []*sql.ColumnType) 
(rowMapper, error) {
+       var record = make([]interface{}, len(columns))
+       var valueProviders = make([]func(index int, values []interface{}), 
len(columns))
+       defaultProvider := func(index int, values []interface{}) {
+               val := new(interface{})
+               values[index] = &val
+       }
+       for i := range columns {
+               valueProviders[i] = defaultProvider
+               if len(columnTypes) == 0 {
+                       continue
+               }
+               dbTypeName := strings.ToLower(columnTypes[i].DatabaseTypeName())
+               if strings.Contains(dbTypeName, "char") || 
strings.Contains(dbTypeName, "string") || strings.Contains(dbTypeName, "text") {
+                       valueProviders[i] = func(index int, values 
[]interface{}) {
+                               val := ""
+                               values[index] = &val
+                       }
+               } else if strings.Contains(dbTypeName, "int") {
+                       valueProviders[i] = func(index int, values 
[]interface{}) {
+                               val := 0
+                               values[index] = &val
+                       }
+               } else if strings.Contains(dbTypeName, "decimal") || 
strings.Contains(dbTypeName, "numeric") || strings.Contains(dbTypeName, 
"float") {
+                       valueProviders[i] = func(index int, values 
[]interface{}) {
+                               val := 0.0
+                               values[index] = &val
+                       }
+               } else if strings.Contains(dbTypeName, "time") || 
strings.Contains(dbTypeName, "date") {
+                       valueProviders[i] = func(index int, values 
[]interface{}) {
+                               val := time.Now()
+                               values[index] = &val
+                       }
+               } else if strings.Contains(dbTypeName, "bool") {
+                       valueProviders[i] = func(index int, values 
[]interface{}) {
+                               val := false
+                               values[index] = &val
+                       }
+               } else {
+                       valueProviders[i] = func(index int, values 
[]interface{}) {
+                               val := 
reflect.New(columnTypes[i].ScanType()).Elem().Interface()
+                               values[index] = &val
+                       }
+               }
+       }
+       mapper := func(value reflect.Value) ([]interface{}, error) {
+               for i := range columns {
+                       valueProviders[i](i, record)
+               }
+               return record, nil
+       }
+       return mapper, nil
+}
+
+//newQueryMapper creates a new record mapped
+func newWriterRowMapper(columns []string, recordType reflect.Type) (rowMapper, 
error) {
+       mappedFieldIndex, err := mapFields(columns, recordType)
+       if err != nil {
+               return nil, err
+       }
+       columnCount := len(columns)
+       mapper := func(value reflect.Value) ([]interface{}, error) {
+               var record = make([]interface{}, columnCount)
+               if value.Kind() == reflect.Ptr {
+                       value = value.Elem() //T = *T
+               }
+               for i, fieldIndex := range mappedFieldIndex {
+                       record[i] = value.Field(fieldIndex).Interface()
+               }
+               return record, nil
+       }
+       return mapper, nil
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/util.go 
b/sdks/go/pkg/beam/io/databaseio/util.go
new file mode 100644
index 00000000000..88888ba3d28
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/util.go
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package databaseio provides transformations and utilities to interact with
// generic databases through the database/sql API. See also:
// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+       "fmt"
+       "reflect"
+       "strings"
+)
+
// mapFields maps each SQL column to the index of the exported field of the
// struct type recordType that corresponds to it.
//
// Each exported field is registered under its Go name, the lower-cased Go
// name, and (when present) its `column:"..."` tag in both original and
// lower-cased form. Tags are registered last, so an explicit tag always wins
// over an implicit field-name match regardless of field order.
//
// A column is looked up, in order, by its exact name, its lower-cased name,
// and its lower-cased name with underscores removed (so "date_of_birth"
// matches DateOfBirth). An error is returned when recordType is not a struct
// or when a column has no matching field.
func mapFields(columns []string, recordType reflect.Type) ([]int, error) {
	if recordType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("failed to map columns: %v is not a struct", recordType)
	}
	fieldCount := recordType.NumField()
	indexedFields := make(map[string]int, 3*fieldCount)
	for i := 0; i < fieldCount; i++ {
		field := recordType.Field(i)
		if field.PkgPath != "" { // unexported
			continue
		}
		indexedFields[field.Name] = i
		indexedFields[strings.ToLower(field.Name)] = i // to account for various matching strategies
	}
	for i := 0; i < fieldCount; i++ {
		field := recordType.Field(i)
		if field.PkgPath != "" {
			continue
		}
		if column := field.Tag.Get("column"); column != "" {
			indexedFields[column] = i
			indexedFields[strings.ToLower(column)] = i
		}
	}
	mappedFieldIndex := make([]int, len(columns))
	for i, column := range columns {
		lower := strings.ToLower(column)
		fieldIndex, ok := indexedFields[column]
		if !ok {
			fieldIndex, ok = indexedFields[lower]
		}
		if !ok {
			fieldIndex, ok = indexedFields[strings.Replace(lower, "_", "", -1)]
		}
		if !ok {
			return nil, fmt.Errorf("failed to matched a %v field for SQL column: %v", recordType, column)
		}
		mappedFieldIndex[i] = fieldIndex
	}
	return mappedFieldIndex, nil
}
+
// asDereferenceSlice replaces, in place, every non-nil pointer in aSlice with
// the value it points to. nil entries, nil pointers and non-pointer values
// are left unchanged instead of causing a reflect panic.
func asDereferenceSlice(aSlice []interface{}) {
	for i, value := range aSlice {
		if value == nil {
			continue
		}
		v := reflect.ValueOf(value)
		if v.Kind() != reflect.Ptr || v.IsNil() {
			continue
		}
		aSlice[i] = v.Elem().Interface()
	}
}
+
// asMap zips keys and values into a map from each key to the value at the
// same index. Keys without a corresponding value (values shorter than keys)
// are omitted; surplus values are ignored.
func asMap(keys []string, values []interface{}) map[string]interface{} {
	result := make(map[string]interface{}, len(keys))
	for i, key := range keys {
		if i >= len(values) {
			break
		}
		result[key] = values[i]
	}
	return result
}
diff --git a/sdks/go/pkg/beam/io/databaseio/util_test.go 
b/sdks/go/pkg/beam/io/databaseio/util_test.go
new file mode 100644
index 00000000000..2b4030107c9
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/util_test.go
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package databaseio provides transformations and utilities to interact with
// generic databases through the database/sql API. See also:
// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+       "github.com/stretchr/testify/assert"
+       "reflect"
+       "testing"
+       "time"
+)
+
+func Test_queryRecordMapperProvider(t *testing.T) {
+
+       type User struct {
+               ID          int
+               local       bool
+               DateOfBirth time.Time
+               NameTest    string `column:"name"`
+               Random      float64
+       }
+
+       mapper, err := newQueryMapper([]string{"id", "name", "random", 
"date_of_birth"}, nil, reflect.TypeOf(User{}))
+       if !assert.Nil(t, err) {
+               return
+       }
+
+       aUser := &User{}
+       record, err := mapper(reflect.ValueOf(aUser))
+       if !assert.Nil(t, err) {
+               return
+       }
+       id, ok := record[0].(*int)
+       assert.True(t, ok)
+       *id = 10
+
+       name, ok := record[1].(*string)
+       assert.True(t, ok)
+       *name = "test"
+
+       random, ok := record[2].(*float64)
+       assert.True(t, ok)
+       *random = 1.2
+
+       dob, ok := record[3].(*time.Time)
+       assert.True(t, ok)
+       now := time.Now()
+       *dob = now
+
+       assert.EqualValues(t, &User{
+               ID:          *id,
+               DateOfBirth: *dob,
+               NameTest:    *name,
+               Random:      *random,
+       }, aUser)
+}
+
+func Test_writerRecordMapperProvider(t *testing.T) {
+       type User struct {
+               ID          int
+               local       bool
+               DateOfBirth time.Time
+               NameTest    string `column:"name"`
+               Random      float64
+       }
+
+       mapper, err := newWriterRowMapper([]string{"id", "name", "random", 
"date_of_birth"}, reflect.TypeOf(User{}))
+       if !assert.Nil(t, err) {
+               return
+       }
+       aUser := &User{
+               ID:          2,
+               NameTest:    "abc",
+               Random:      1.6,
+               DateOfBirth: time.Now(),
+       }
+       record, err := mapper(reflect.ValueOf(aUser))
+       if !assert.Nil(t, err) {
+               return
+       }
+       assert.EqualValues(t, 2, record[0])
+       assert.EqualValues(t, "abc", record[1])
+       assert.EqualValues(t, 1.6, record[2])
+       assert.EqualValues(t, aUser.DateOfBirth, record[3])
+
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/writer.go 
b/sdks/go/pkg/beam/io/databaseio/writer.go
new file mode 100644
index 00000000000..7b5a6381eb5
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/writer.go
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
// Package databaseio provides transformations and utilities to interact with
// generic databases through the database/sql API. See also:
// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+       "database/sql"
+       "fmt"
+       "golang.org/x/net/context"
+       "strings"
+)
+
// Writer is implemented by element types that supply their own row data to
// Write. SaveData returns the row as a map from column name to value;
// columns absent from the map are inserted as nil.
type Writer interface {
	SaveData() (map[string]interface{}, error)
}
+
// writer buffers rows and flushes them as multi-row INSERT statements.
type writer struct {
	// table is the name of the target table.
	table string
	// sqlTemplate is the "INSERT INTO <table>(<columns>) VALUES" prefix.
	sqlTemplate string
	// valueTempate is the placeholder group for a single row, e.g. "(?,?)".
	valueTempate string
	// binding holds the values of all buffered rows, flattened in row order.
	binding []interface{}
	// batchSize is the number of buffered rows at which a batch is flushed.
	batchSize int
	// columnCount is the number of values every row must have.
	columnCount int
	// rowCount is the number of rows currently buffered.
	rowCount int
	// totalCount is the number of rows added over the writer's lifetime.
	totalCount int
}
+
+func (w *writer) add(row []interface{}) error {
+       w.rowCount++
+       w.totalCount++
+       if len(row) != w.columnCount {
+               return fmt.Errorf("expected %v row values, but had: %v", 
w.columnCount, len(row))
+       }
+       w.binding = append(w.binding, row...)
+       return nil
+}
+
+func (w *writer) write(ctx context.Context, db *sql.DB) error {
+       values := strings.Repeat(w.valueTempate+",", w.rowCount)
+       SQL := w.sqlTemplate + string(values[:len(values)-1])
+       resultSet, err := db.ExecContext(ctx, SQL, w.binding...)
+       if err != nil {
+               return err
+       }
+       affected, _ := resultSet.RowsAffected()
+       if int(affected) != w.rowCount {
+               return fmt.Errorf("expected to write: %v, but written: %v", 
w.rowCount, affected)
+       }
+       w.binding = []interface{}{}
+       w.rowCount = 0
+       return nil
+}
+
+func (w *writer) writeBatchIfNeeded(ctx context.Context, db *sql.DB) error {
+       if w.rowCount >= w.batchSize {
+               return w.write(ctx, db)
+       }
+       return nil
+}
+
+func (w *writer) writeIfNeeded(ctx context.Context, db *sql.DB) error {
+       if w.rowCount >= 0 {
+               return w.write(ctx, db)
+       }
+       return nil
+}
+
+func newWriter(batchSize int, table string, columns []string) (*writer, error) 
{
+       if len(columns) == 0 {
+               return nil, fmt.Errorf("columns were empty")
+       }
+       values := strings.Repeat("?,", len(columns))
+       return &writer{
+               batchSize:    batchSize,
+               columnCount:  len(columns),
+               table:        table,
+               binding:      make([]interface{}, 0),
+               sqlTemplate:  fmt.Sprintf("INSERT INTO %v(%v) VALUES", table, 
strings.Join(columns, ",")),
+               valueTempate: fmt.Sprintf("(%s)", values[:len(values)-1]),
+       }, nil
+}
diff --git a/sdks/python/apache_beam/testing/util_test.py 
b/sdks/python/apache_beam/testing/util_test.py
index f46063e0871..83b68e811d8 100644
--- a/sdks/python/apache_beam/testing/util_test.py
+++ b/sdks/python/apache_beam/testing/util_test.py
@@ -38,6 +38,14 @@ def test_assert_that_passes(self):
     with TestPipeline() as p:
       assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3]))
 
+  def test_assert_that_passes_order_does_not_matter(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([1, 2, 3]), equal_to([2, 1, 3]))
+
+  def test_assert_that_passes_order_does_not_matter_with_negatives(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([1, -2, 3]), equal_to([-2, 1, 3]))
+
   def test_assert_that_passes_empty_equal_to(self):
     with TestPipeline() as p:
       assert_that(p | Create([]), equal_to([]))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


Issue Time Tracking
-------------------

    Worklog Id:     (was: 166218)
    Time Spent: 1h 40m  (was: 1.5h)

> Create ability to read/write database implementing database/sql  contract
> -------------------------------------------------------------------------
>
>                 Key: BEAM-5729
>                 URL: https://issues.apache.org/jira/browse/BEAM-5729
>             Project: Beam
>          Issue Type: Improvement
>          Components: sdk-go
>    Affects Versions: 2.7.0
>            Reporter: Adrian Witas
>            Priority: Major
>          Time Spent: 1h 40m
>  Remaining Estimate: 0h
>




--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to