[
https://issues.apache.org/jira/browse/BEAM-5729?focusedWorklogId=166218&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-166218
]
ASF GitHub Bot logged work on BEAM-5729:
----------------------------------------
Author: ASF GitHub Bot
Created on: 15/Nov/18 00:25
Start Date: 15/Nov/18 00:25
Worklog Time Spent: 10m
Work Description: chamikaramj closed pull request #6676: [BEAM-5729]
added database/sql reader/writer
URL: https://github.com/apache/beam/pull/6676
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git a/sdks/go/pkg/beam/io/databaseio/database.go
b/sdks/go/pkg/beam/io/databaseio/database.go
new file mode 100644
index 00000000000..e9de9cbbbe9
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/database.go
@@ -0,0 +1,219 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package databaseio provides transformations and utilities to interact with
+// a generic database through the database/sql API. See also:
+// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "github.com/apache/beam/sdks/go/pkg/beam"
+ "github.com/apache/beam/sdks/go/pkg/beam/log"
+ "reflect"
+ "strings"
+)
+
+// init registers the queryFn and writeFn struct types with beam.RegisterType.
+// NOTE(review): presumably required so the runner can serialize these DoFns;
+// confirm against the Beam Go SDK documentation.
+func init() {
+ beam.RegisterType(reflect.TypeOf((*queryFn)(nil)).Elem())
+ beam.RegisterType(reflect.TypeOf((*writeFn)(nil)).Elem())
+}
+
+// writeRowLimit is the default batch size used by Write: the number of rows
+// buffered before a batch INSERT statement is issued.
+const writeRowLimit = 1000
+
+// Read reads all rows from the given table by running "SELECT * from <table>"
+// and returns a PCollection<t>. The table must have a schema compatible with
+// the given type, t: the returned columns are mapped onto t by the query
+// transform in this package.
+func Read(s beam.Scope, driver, dsn, table string, t reflect.Type) beam.PCollection {
+	selectAll := "SELECT * from " + table
+	return query(s.Scope(driver+".Read"), driver, dsn, selectAll, t)
+}
+
+// Query executes the SQL statement q against the database identified by
+// driver and dsn. The result set must have a schema compatible with the given
+// type, t. It returns a PCollection<t>.
+func Query(s beam.Scope, driver, dsn, q string, t reflect.Type) beam.PCollection {
+	scoped := s.Scope(driver + ".Query")
+	return query(scoped, driver, dsn, q, t)
+}
+
+// query builds the read transform shared by Read and Query: a single impulse
+// element drives a queryFn ParDo whose output element type is bound to t.
+func query(s beam.Scope, driver, dsn, query string, t reflect.Type) beam.PCollection {
+	imp := beam.Impulse(s)
+	fn := &queryFn{
+		Driver: driver,
+		Dsn:    dsn,
+		Query:  query,
+		Type:   beam.EncodedType{T: t},
+	}
+	outputType := beam.TypeDefinition{Var: beam.XType, T: t}
+	return beam.ParDo(s, fn, imp, outputType)
+}
+
+// queryFn is the DoFn that runs a SQL query and emits one element per row.
+type queryFn struct {
+ // Driver is the database/sql driver name passed to sql.Open.
+ Driver string `json:"driver"`
+ // Dsn is the data source name passed to sql.Open.
+ Dsn string `json:"dsn"`
+ // Query is the SQL text that is prepared and executed.
+ Query string `json:"query"`
+ // Type is the encoded element type; a new value of Type.T is created for each row.
+ Type beam.EncodedType `json:"type"`
+}
+
+// ProcessElement opens the database, runs f.Query and emits one value of type
+// f.Type.T per fetched row.
+//
+// The row mapper is created lazily on the first row from the result set's
+// column names. When *T implements MapLoader or SliceLoader the scanned values
+// are dereferenced and handed to the loader; otherwise the values are scanned
+// directly into the fields of T.
+//
+// It returns an error if the database cannot be opened, the query cannot be
+// prepared or run, a row cannot be mapped or scanned, a loader reports an
+// error, or row iteration stops because of an error.
+func (f *queryFn) ProcessElement(ctx context.Context, _ []byte, emit func(beam.X)) error {
+	//TODO move DB Open and Close to Setup and Teardown methods or StartBundle and FinishBundle
+	db, err := sql.Open(f.Driver, f.Dsn)
+	if err != nil {
+		return fmt.Errorf("failed to open database: %v, %v", f.Driver, err)
+	}
+	defer db.Close()
+	statement, err := db.PrepareContext(ctx, f.Query)
+	if err != nil {
+		return fmt.Errorf("failed to prepare query: %v, %v", f.Query, err)
+	}
+	defer statement.Close()
+	rows, err := statement.QueryContext(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to run query: %v, %v", f.Query, err)
+	}
+	defer rows.Close()
+	var mapper rowMapper
+	var columns []string
+	for rows.Next() {
+		reflectRow := reflect.New(f.Type.T)
+		row := reflectRow.Interface() // row : *T
+		if mapper == nil {
+			if columns, err = rows.Columns(); err != nil {
+				return err
+			}
+			// Column type metadata is best effort: the error is deliberately
+			// ignored because newQueryMapper also accepts an empty type slice.
+			columnsTypes, _ := rows.ColumnTypes()
+			if mapper, err = newQueryMapper(columns, columnsTypes, f.Type.T); err != nil {
+				return fmt.Errorf("failed to create rowValues mapper: %v", err)
+			}
+		}
+		rowValues, err := mapper(reflectRow)
+		if err != nil {
+			return err
+		}
+		if err = rows.Scan(rowValues...); err != nil {
+			return fmt.Errorf("failed to scan %v, %v", f.Query, err)
+		}
+		// Loader errors used to be dropped silently; surface them instead of
+		// emitting a partially loaded element.
+		switch loader := row.(type) {
+		case MapLoader:
+			asDereferenceSlice(rowValues)
+			if err = loader.LoadMap(asMap(columns, rowValues)); err != nil {
+				return fmt.Errorf("failed to load row as map %v, %v", f.Query, err)
+			}
+		case SliceLoader:
+			asDereferenceSlice(rowValues)
+			if err = loader.LoadSlice(rowValues); err != nil {
+				return fmt.Errorf("failed to load row as slice %v, %v", f.Query, err)
+			}
+		}
+		emit(reflectRow.Elem().Interface()) // emit(*row)
+	}
+	// rows.Next returns false both at end of data and on failure; only
+	// rows.Err tells the two apart.
+	if err = rows.Err(); err != nil {
+		return fmt.Errorf("failed to read rows: %v, %v", f.Query, err)
+	}
+	return nil
+}
+
+// Write writes the elements of the given PCollection<T> to the database table
+// using the default batch size, writeRowLimit. If columns is left empty all
+// table columns are inserted into, otherwise only the selected ones.
+func Write(s beam.Scope, driver, dsn, table string, columns []string, col beam.PCollection) {
+	// Same pipeline as WriteWithBatchSize, only with the default batch size.
+	WriteWithBatchSize(s, writeRowLimit, driver, dsn, table, columns, col)
+}
+
+// WriteWithBatchSize writes the elements of the given PCollection<T> to the
+// database table with a custom batch size. The batch size controls the number
+// of rows placed in each batch INSERT statement.
+//
+// All elements are grouped under one fixed key before being handed to a
+// single writeFn invocation.
+func WriteWithBatchSize(s beam.Scope, batchSize int, driver, dsn, table string, columns []string, col beam.PCollection) {
+	elemType := col.Type().Type()
+	s = s.Scope(driver + ".Write")
+	grouped := beam.GroupByKey(s, beam.AddFixedKey(s, col))
+	fn := &writeFn{
+		Driver:    driver,
+		Dsn:       dsn,
+		Table:     table,
+		Columns:   columns,
+		BatchSize: batchSize,
+		Type:      beam.EncodedType{T: elemType},
+	}
+	beam.ParDo0(s, fn, grouped)
+}
+
+// writeFn is the DoFn that inserts the elements of a grouped PCollection into a table.
+type writeFn struct {
+ // Driver is the database/sql driver name passed to sql.Open.
+ Driver string `json:"driver"`
+ // Dsn is the data source name passed to sql.Open.
+ Dsn string `json:"dsn"`
+ // Table is the table identifier.
+ Table string `json:"table"`
+ // Columns to insert into; if empty then all table columns are used.
+ Columns []string `json:"columns"`
+ // BatchSize is the number of buffered rows that triggers a batch INSERT.
+ BatchSize int `json:"batchSize"`
+ // Type is the encoded schema type.
+ Type beam.EncodedType `json:"type"`
+}
+
+// ProcessElement inserts every element supplied by iter into f.Table.
+//
+// It discovers the target column names, then maps each element to a row:
+// elements implementing Writer provide their own data through SaveData (a
+// column missing from the returned map is inserted as nil), all other
+// elements are mapped field by field. Rows are buffered and flushed as batch
+// INSERT statements of f.BatchSize rows; the remainder is flushed at the end.
+//
+// It returns an error if the database cannot be opened, the columns cannot be
+// discovered, an element cannot be mapped, or an INSERT fails.
+func (f *writeFn) ProcessElement(ctx context.Context, _ int, iter func(*beam.X) bool) error {
+	//TODO move DB Open and Close to Setup and Teardown methods or StartBundle and FinishBundle
+	db, err := sql.Open(f.Driver, f.Dsn)
+	if err != nil {
+		return fmt.Errorf("failed to open database: %v, %v", f.Driver, err)
+	}
+	defer db.Close()
+	columns, err := discoverTableColumns(ctx, db, f.Table, f.Columns)
+	if err != nil {
+		return err
+	}
+	//TODO move to Setup methods
+	mapper, err := newWriterRowMapper(columns, f.Type.T)
+	if err != nil {
+		return fmt.Errorf("failed to create row mapper: %v", err)
+	}
+	writer, err := newWriter(f.BatchSize, f.Table, columns)
+	if err != nil {
+		return err
+	}
+	var val beam.X
+	for iter(&val) {
+		var row []interface{}
+		// saver is named differently from writer so the batch writer is not shadowed.
+		if saver, ok := val.(Writer); ok {
+			var data map[string]interface{}
+			if data, err = saver.SaveData(); err == nil {
+				row = make([]interface{}, len(columns))
+				for i, column := range columns {
+					row[i] = data[column]
+				}
+			}
+		} else {
+			row, err = mapper(reflect.ValueOf(val))
+		}
+		if err != nil {
+			return fmt.Errorf("failed to map row %T: %v", val, err)
+		}
+		if err = writer.add(row); err != nil {
+			return err
+		}
+		if err := writer.writeBatchIfNeeded(ctx, db); err != nil {
+			return err
+		}
+	}
+	if err := writer.writeIfNeeded(ctx, db); err != nil {
+		return err
+	}
+	log.Infof(ctx, "written %v row(s) into %v", writer.totalCount, f.Table)
+	return nil
+}
+
+// discoverTableColumns returns the column names of table by running
+// "SELECT <projection> FROM <table> WHERE 1 = 0", where the projection is the
+// selected columns, or "*" when none are given. The statement and its result
+// set are closed before returning so that nothing stays open while rows are
+// being inserted.
+func discoverTableColumns(ctx context.Context, db *sql.DB, table string, selected []string) ([]string, error) {
+	projection := "*"
+	if len(selected) > 0 {
+		projection = strings.Join(selected, ",")
+	}
+	dql := fmt.Sprintf("SELECT %v FROM %v WHERE 1 = 0", projection, table)
+	query, err := db.PrepareContext(ctx, dql)
+	if err != nil {
+		return nil, fmt.Errorf("failed to prepare query: %v, %v", table, err)
+	}
+	defer query.Close()
+	rows, err := query.QueryContext(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query: %v, %v", table, err)
+	}
+	defer rows.Close()
+	columns, err := rows.Columns()
+	if err != nil {
+		return nil, fmt.Errorf("failed to discover column: %v, %v", table, err)
+	}
+	return columns, nil
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/loader.go
b/sdks/go/pkg/beam/io/databaseio/loader.go
new file mode 100644
index 00000000000..9dabe17df92
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/loader.go
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package databaseio provides transformations and utilities to interact with
+// a generic database through the database/sql API. See also:
+// https://golang.org/pkg/database/sql/
+package databaseio
+
+// MapLoader is implemented by element types that want each fetched row
+// delivered through LoadMap as a map keyed by column name.
+type MapLoader interface {
+ LoadMap(row map[string]interface{}) error
+}
+
+// SliceLoader is implemented by element types that want each fetched row
+// delivered through LoadSlice as a slice of values.
+type SliceLoader interface {
+ LoadSlice(row []interface{}) error
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/mapper.go
b/sdks/go/pkg/beam/io/databaseio/mapper.go
new file mode 100644
index 00000000000..7c55f80b22b
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/mapper.go
@@ -0,0 +1,132 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package databaseio provides transformations and utilities to interact with
+// a generic database through the database/sql API. See also:
+// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+ "database/sql"
+ "fmt"
+ "reflect"
+ "strings"
+ "time"
+)
+
+// rowMapper maps a record value to a slice with one entry per column. Query
+// mappers return scan destinations (pointers); the writer mapper returns the
+// field values to insert.
+type rowMapper func(value reflect.Value) ([]interface{}, error)
+
+// newQueryMapper picks the query row mapper for recordType: a loader mapper
+// when *recordType implements MapLoader, a struct mapper when recordType is a
+// struct, and an "unsupported type" error for anything else.
+func newQueryMapper(columns []string, columnTypes []*sql.ColumnType, recordType reflect.Type) (rowMapper, error) {
+	_, isLoader := reflect.New(recordType).Interface().(MapLoader)
+	switch {
+	case isLoader:
+		return newQueryLoaderMapper(columns, columnTypes)
+	case recordType.Kind() == reflect.Struct:
+		return newQueryStructMapper(columns, recordType)
+	default:
+		return nil, fmt.Errorf("unsupported type %s", recordType)
+	}
+}
+
+//newQueryStructMapper creates a new record mapper for supplied struct type
+func newQueryStructMapper(columns []string, recordType reflect.Type)
(rowMapper, error) {
+ mappedFieldIndex, err := mapFields(columns, recordType)
+ if err != nil {
+ return nil, err
+ }
+ var record = make([]interface{}, recordType.NumField())
+ var mapper = func(value reflect.Value) ([]interface{}, error) {
+ value = value.Elem() //T = *T
+ for i, fieldIndex := range mappedFieldIndex {
+ record[i] = value.Field(fieldIndex).Addr().Interface()
+ }
+ return record, nil
+ }
+ return mapper, nil
+}
+
+//newQueryStructMapper creates a new record mapper for supplied struct type
+func newQueryLoaderMapper(columns []string, columnTypes []*sql.ColumnType)
(rowMapper, error) {
+ var record = make([]interface{}, len(columns))
+ var valueProviders = make([]func(index int, values []interface{}),
len(columns))
+ defaultProvider := func(index int, values []interface{}) {
+ val := new(interface{})
+ values[index] = &val
+ }
+ for i := range columns {
+ valueProviders[i] = defaultProvider
+ if len(columnTypes) == 0 {
+ continue
+ }
+ dbTypeName := strings.ToLower(columnTypes[i].DatabaseTypeName())
+ if strings.Contains(dbTypeName, "char") ||
strings.Contains(dbTypeName, "string") || strings.Contains(dbTypeName, "text") {
+ valueProviders[i] = func(index int, values
[]interface{}) {
+ val := ""
+ values[index] = &val
+ }
+ } else if strings.Contains(dbTypeName, "int") {
+ valueProviders[i] = func(index int, values
[]interface{}) {
+ val := 0
+ values[index] = &val
+ }
+ } else if strings.Contains(dbTypeName, "decimal") ||
strings.Contains(dbTypeName, "numeric") || strings.Contains(dbTypeName,
"float") {
+ valueProviders[i] = func(index int, values
[]interface{}) {
+ val := 0.0
+ values[index] = &val
+ }
+ } else if strings.Contains(dbTypeName, "time") ||
strings.Contains(dbTypeName, "date") {
+ valueProviders[i] = func(index int, values
[]interface{}) {
+ val := time.Now()
+ values[index] = &val
+ }
+ } else if strings.Contains(dbTypeName, "bool") {
+ valueProviders[i] = func(index int, values
[]interface{}) {
+ val := false
+ values[index] = &val
+ }
+ } else {
+ valueProviders[i] = func(index int, values
[]interface{}) {
+ val :=
reflect.New(columnTypes[i].ScanType()).Elem().Interface()
+ values[index] = &val
+ }
+ }
+ }
+ mapper := func(value reflect.Value) ([]interface{}, error) {
+ for i := range columns {
+ valueProviders[i](i, record)
+ }
+ return record, nil
+ }
+ return mapper, nil
+}
+
+//newQueryMapper creates a new record mapped
+func newWriterRowMapper(columns []string, recordType reflect.Type) (rowMapper,
error) {
+ mappedFieldIndex, err := mapFields(columns, recordType)
+ if err != nil {
+ return nil, err
+ }
+ columnCount := len(columns)
+ mapper := func(value reflect.Value) ([]interface{}, error) {
+ var record = make([]interface{}, columnCount)
+ if value.Kind() == reflect.Ptr {
+ value = value.Elem() //T = *T
+ }
+ for i, fieldIndex := range mappedFieldIndex {
+ record[i] = value.Field(fieldIndex).Interface()
+ }
+ return record, nil
+ }
+ return mapper, nil
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/util.go
b/sdks/go/pkg/beam/io/databaseio/util.go
new file mode 100644
index 00000000000..88888ba3d28
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/util.go
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package databaseio provides transformations and utilities to interact with
+// a generic database through the database/sql API. See also:
+// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+ "fmt"
+ "reflect"
+ "strings"
+)
+
+//mapFields maps column into field index in record type
+func mapFields(columns []string, recordType reflect.Type) ([]int, error) {
+ var indexedFields = map[string]int{}
+ for i := 0; i < recordType.NumField(); i++ {
+ if isExported := recordType.Field(i).PkgPath == ""; !isExported
{
+ continue
+ }
+ fieldName := recordType.Field(i).Name
+ indexedFields[fieldName] = i
+ indexedFields[strings.ToLower(fieldName)] = i //to account for
various matching strategies
+ aTag := recordType.Field(i).Tag
+ if column := aTag.Get("column"); column != "" {
+ indexedFields[column] = i
+ }
+ }
+ var mappedFieldIndex = make([]int, len(columns))
+ for i, column := range columns {
+ fieldIndex, ok := indexedFields[column]
+ if !ok {
+ fieldIndex, ok = indexedFields[strings.ToLower(column)]
+ }
+ if !ok {
+ fieldIndex, ok =
indexedFields[strings.Replace(strings.ToLower(column), "_", "",
strings.Count(column, "_"))]
+ }
+ if !ok {
+ return nil, fmt.Errorf("failed to matched a %v field
for SQL column: %v", recordType, column)
+ }
+ mappedFieldIndex[i] = fieldIndex
+ }
+ return mappedFieldIndex, nil
+}
+
+// asDereferenceSlice replaces, in place, every non-nil entry of aSlice with
+// the value it points to. Nil entries are left untouched.
+func asDereferenceSlice(aSlice []interface{}) {
+	for i := 0; i < len(aSlice); i++ {
+		if pointer := aSlice[i]; pointer != nil {
+			aSlice[i] = reflect.ValueOf(pointer).Elem().Interface()
+		}
+	}
+}
+
+func asMap(keys []string, values []interface{}) map[string]interface{} {
+ var result = make(map[string]interface{})
+ for i, key := range keys {
+ result[key] = values[i]
+ }
+ return result
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/util_test.go
b/sdks/go/pkg/beam/io/databaseio/util_test.go
new file mode 100644
index 00000000000..2b4030107c9
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/util_test.go
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package databaseio provides transformations and utilities to interact with
+// a generic database through the database/sql API. See also:
+// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+ "github.com/stretchr/testify/assert"
+ "reflect"
+ "testing"
+ "time"
+)
+
+// Test_queryRecordMapperProvider checks that the query mapper built for a
+// struct type returns, per column, a pointer to the matching struct field, so
+// that writing through the pointers populates the struct.
+func Test_queryRecordMapperProvider(t *testing.T) {
+	type User struct {
+		ID          int
+		local       bool
+		DateOfBirth time.Time
+		NameTest    string `column:"name"`
+		Random      float64
+	}
+
+	mapper, err := newQueryMapper([]string{"id", "name", "random", "date_of_birth"}, nil, reflect.TypeOf(User{}))
+	if !assert.Nil(t, err) {
+		return
+	}
+
+	aUser := &User{}
+	record, err := mapper(reflect.ValueOf(aUser))
+	if !assert.Nil(t, err) {
+		return
+	}
+	if !assert.True(t, len(record) >= 4) {
+		return
+	}
+
+	// Stop after a failed type assertion: the pointer would be nil and
+	// writing through it would panic instead of reporting the failure.
+	id, ok := record[0].(*int)
+	if !assert.True(t, ok) {
+		return
+	}
+	*id = 10
+
+	name, ok := record[1].(*string)
+	if !assert.True(t, ok) {
+		return
+	}
+	*name = "test"
+
+	random, ok := record[2].(*float64)
+	if !assert.True(t, ok) {
+		return
+	}
+	*random = 1.2
+
+	dob, ok := record[3].(*time.Time)
+	if !assert.True(t, ok) {
+		return
+	}
+	now := time.Now()
+	*dob = now
+
+	assert.EqualValues(t, &User{
+		ID:          *id,
+		DateOfBirth: *dob,
+		NameTest:    *name,
+		Random:      *random,
+	}, aUser)
+}
+
+// Test_writerRecordMapperProvider checks that the writer mapper returns the
+// struct field values in the order of the requested columns.
+func Test_writerRecordMapperProvider(t *testing.T) {
+	type User struct {
+		ID          int
+		local       bool
+		DateOfBirth time.Time
+		NameTest    string `column:"name"`
+		Random      float64
+	}
+
+	columns := []string{"id", "name", "random", "date_of_birth"}
+	mapper, err := newWriterRowMapper(columns, reflect.TypeOf(User{}))
+	if !assert.Nil(t, err) {
+		return
+	}
+	aUser := &User{
+		ID:          2,
+		NameTest:    "abc",
+		Random:      1.6,
+		DateOfBirth: time.Now(),
+	}
+	record, err := mapper(reflect.ValueOf(aUser))
+	if !assert.Nil(t, err) {
+		return
+	}
+	// Expected values, listed in column order.
+	expected := []interface{}{2, "abc", 1.6, aUser.DateOfBirth}
+	for i, want := range expected {
+		assert.EqualValues(t, want, record[i])
+	}
+}
diff --git a/sdks/go/pkg/beam/io/databaseio/writer.go
b/sdks/go/pkg/beam/io/databaseio/writer.go
new file mode 100644
index 00000000000..7b5a6381eb5
--- /dev/null
+++ b/sdks/go/pkg/beam/io/databaseio/writer.go
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package databaseio provides transformations and utilities to interact with
+// a generic database through the database/sql API. See also:
+// https://golang.org/pkg/database/sql/
+package databaseio
+
+import (
+ "database/sql"
+ "fmt"
+ "golang.org/x/net/context"
+ "strings"
+)
+
+// Writer is implemented by element types that supply their own row data.
+// SaveData returns a row of data to be inserted into a table, as a map keyed
+// by column name.
+type Writer interface {
+ SaveData() (map[string]interface{}, error)
+}
+
+// writer buffers rows and flushes them to a table as multi-row INSERT
+// statements using "?" placeholders.
+// NOTE(review): drivers that expect another placeholder style presumably
+// cannot execute these statements; confirm per driver.
+type writer struct {
+	// batchSize is the buffered row count at which writeBatchIfNeeded flushes.
+	batchSize int
+	// table is the target table name.
+	table string
+	// sqlTemplate is the "INSERT INTO table(columns) VALUES" statement prefix.
+	sqlTemplate string
+	// valueTemplate is the "(?,?,...)" placeholder group for a single row.
+	valueTemplate string
+	// binding holds the flattened values of all buffered rows.
+	binding []interface{}
+	// columnCount is the number of values every row must provide.
+	columnCount int
+	// rowCount is the number of rows currently buffered.
+	rowCount int
+	// totalCount is the number of rows accepted by add since creation.
+	totalCount int
+}
+
+// add buffers one row. It returns an error, and leaves the writer unchanged,
+// when the row does not hold exactly one value per column.
+func (w *writer) add(row []interface{}) error {
+	// Validate before counting so a rejected row never skews the counters.
+	if len(row) != w.columnCount {
+		return fmt.Errorf("expected %v row values, but had: %v", w.columnCount, len(row))
+	}
+	w.rowCount++
+	w.totalCount++
+	w.binding = append(w.binding, row...)
+	return nil
+}
+
+// write flushes all buffered rows with a single INSERT statement and resets
+// the buffer. It is a no-op when nothing is buffered. It returns an error
+// when the statement fails, when the affected row count cannot be read, or
+// when that count differs from the number of buffered rows.
+func (w *writer) write(ctx context.Context, db *sql.DB) error {
+	if w.rowCount == 0 {
+		// Nothing to insert; without this guard the trailing-comma trim
+		// below would slice an empty string and panic.
+		return nil
+	}
+	groups := strings.Repeat(w.valueTemplate+",", w.rowCount)
+	statement := w.sqlTemplate + groups[:len(groups)-1]
+	result, err := db.ExecContext(ctx, statement, w.binding...)
+	if err != nil {
+		return err
+	}
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return fmt.Errorf("failed to read affected row count: %v", err)
+	}
+	if int(affected) != w.rowCount {
+		return fmt.Errorf("expected to write: %v, but written: %v", w.rowCount, affected)
+	}
+	w.binding = w.binding[:0]
+	w.rowCount = 0
+	return nil
+}
+
+// writeBatchIfNeeded flushes the buffer once it holds at least batchSize rows.
+func (w *writer) writeBatchIfNeeded(ctx context.Context, db *sql.DB) error {
+	if w.rowCount >= w.batchSize {
+		return w.write(ctx, db)
+	}
+	return nil
+}
+
+// writeIfNeeded flushes whatever is still buffered; it does nothing when the
+// buffer is empty.
+func (w *writer) writeIfNeeded(ctx context.Context, db *sql.DB) error {
+	if w.rowCount > 0 {
+		return w.write(ctx, db)
+	}
+	return nil
+}
+
+// newWriter creates a writer that inserts into the given columns of table,
+// flushing every batchSize rows. It returns an error when columns is empty.
+func newWriter(batchSize int, table string, columns []string) (*writer, error) {
+	if len(columns) == 0 {
+		return nil, fmt.Errorf("columns were empty")
+	}
+	placeholders := strings.Repeat("?,", len(columns))
+	return &writer{
+		batchSize:     batchSize,
+		columnCount:   len(columns),
+		table:         table,
+		binding:       make([]interface{}, 0),
+		sqlTemplate:   fmt.Sprintf("INSERT INTO %v(%v) VALUES", table, strings.Join(columns, ",")),
+		valueTemplate: fmt.Sprintf("(%s)", placeholders[:len(placeholders)-1]),
+	}, nil
+}
diff --git a/sdks/python/apache_beam/testing/util_test.py
b/sdks/python/apache_beam/testing/util_test.py
index f46063e0871..83b68e811d8 100644
--- a/sdks/python/apache_beam/testing/util_test.py
+++ b/sdks/python/apache_beam/testing/util_test.py
@@ -38,6 +38,14 @@ def test_assert_that_passes(self):
with TestPipeline() as p:
assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3]))
+ def test_assert_that_passes_order_does_not_matter(self):
+ with TestPipeline() as p:
+ assert_that(p | Create([1, 2, 3]), equal_to([2, 1, 3]))
+
+ def test_assert_that_passes_order_does_not_matter_with_negatives(self):
+ with TestPipeline() as p:
+ assert_that(p | Create([1, -2, 3]), equal_to([-2, 1, 3]))
+
def test_assert_that_passes_empty_equal_to(self):
with TestPipeline() as p:
assert_that(p | Create([]), equal_to([]))
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
Issue Time Tracking
-------------------
Worklog Id: (was: 166218)
Time Spent: 1h 40m (was: 1.5h)
> Create ability to read/write database implementing database/sql contract
> -------------------------------------------------------------------------
>
> Key: BEAM-5729
> URL: https://issues.apache.org/jira/browse/BEAM-5729
> Project: Beam
> Issue Type: Improvement
> Components: sdk-go
> Affects Versions: 2.7.0
> Reporter: Adrian Witas
> Priority: Major
> Time Spent: 1h 40m
> Remaining Estimate: 0h
>
--
This message was sent by Atlassian JIRA
(v7.6.3#76005)