shamrickus commented on a change in pull request #5924:
URL: https://github.com/apache/trafficcontrol/pull/5924#discussion_r651935876



##########
File path: tools/traffic_vault_migrate/postgres.go
##########
@@ -0,0 +1,743 @@
+package main
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import (
+       "database/sql"
+       "encoding/base64"
+       "encoding/json"
+       "errors"
+       "fmt"
+       "log"
+       "strconv"
+       "strings"
+
+       "github.com/apache/trafficcontrol/lib/go-tc"
+       util "github.com/apache/trafficcontrol/lib/go-util"
+
+       _ "github.com/lib/pq"
+)
+
+// PGConfig represents the configuration options available to the PG backend
+type PGConfig struct {
+       Host     string `json:"host"`
+       Port     string `json:"port"`
+       User     string `json:"user"`
+       Password string `json:"password"`
+       SSLMode  string `json:"sslmode"`
+       Database string `json:"database"`
+       Key      string `json:"aesKey"`
+       AESKey   []byte
+}
+
+// PGBackend is the Postgres implementation of TVBackend
+type PGBackend struct {
+       sslKey pgSSLKeyTable
+       dnssec pgDNSSecTable
+       uri    pgURISignKeyTable
+       url    pgURLSigKeyTable
+       cfg    PGConfig
+       db     *sql.DB
+}
+
+// String returns a high level overview of the backend and its keys
+func (pg *PGBackend) String() string {
+       data := fmt.Sprintf("PG server %v@%v:%v\n", pg.cfg.User, pg.cfg.Host, 
pg.cfg.Port)
+       data += fmt.Sprintf("\tSSL Keys: %v\n", len(pg.sslKey.Records))
+       data += fmt.Sprintf("\tDNSSec Keys: %v\n", len(pg.dnssec.Records))
+       data += fmt.Sprintf("\tURI Keys: %v\n", len(pg.uri.Records))
+       data += fmt.Sprintf("\tURL Keys: %v\n", len(pg.url.Records))
+       return data
+}
+
+// Name returns the name for this backend
+func (pg *PGBackend) Name() string {
+       return "PG"
+}
+
+// ReadConfig takes in a filename and will read it into the backends config
+func (pg *PGBackend) ReadConfig(s string) error {
+       err := UnmarshalConfig(s, &pg.cfg)
+       if err != nil {
+               return err
+       }
+
+       pg.cfg.AESKey, err = base64.StdEncoding.DecodeString(pg.cfg.Key)
+       if err != nil {
+               return fmt.Errorf("unable to decode PG AESKey: %w", err)
+       }
+       return nil
+}
+
+// Insert takes the current keys and inserts them into the backend DB
+func (pg *PGBackend) Insert() error {
+       if err := pg.sslKey.insertKeys(pg.db); err != nil {
+               return err
+       }
+       if err := pg.dnssec.insertKeys(pg.db); err != nil {
+               return err
+       }
+       if err := pg.url.insertKeys(pg.db); err != nil {
+               return err
+       }
+       if err := pg.uri.insertKeys(pg.db); err != nil {
+               return err
+       }
+       return nil
+}
+
+// Start initiates the connection to the backend DB
+func (pg *PGBackend) Start() error {
+       sqlStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", 
pg.cfg.User, pg.cfg.Password, pg.cfg.Host, pg.cfg.Port, pg.cfg.Database, 
pg.cfg.SSLMode)
+       db, err := sql.Open("postgres", sqlStr)
+       if err != nil {
+               return fmt.Errorf("unable to start PG client: %w", err)
+       }
+
+       pg.db = db
+       pg.sslKey = pgSSLKeyTable{}
+       pg.dnssec = pgDNSSecTable{}
+       pg.url = pgURLSigKeyTable{}
+       pg.uri = pgURISignKeyTable{}
+
+       return nil
+}
+
+// ValidateKey validates that the keys are valid (in most cases, certain 
fields are not null)
+func (pg *PGBackend) ValidateKey() []string {
+       var errors []string
+       if errs := pg.sslKey.validate(); errs != nil {
+               errors = append(errors, errs...)
+       }
+       if errs := pg.dnssec.validate(); errs != nil {
+               errors = append(errors, errs...)
+       }
+       if errs := pg.uri.validate(); errs != nil {
+               errors = append(errors, errs...)
+       }
+       if errs := pg.url.validate(); errs != nil {
+               errors = append(errors, errs...)
+       }
+       return errors
+}
+
+// Close terminates the connection to the backend DB
+func (pg *PGBackend) Close() error {
+       return pg.db.Close()
+}
+
+// Ping checks the connection to the backend DB
+func (pg *PGBackend) Ping() error {
+       return pg.db.Ping()
+}
+
+// Fetch gets all of the keys from the backend DB
+func (pg *PGBackend) Fetch() error {
+       if err := pg.sslKey.gatherKeys(pg.db); err != nil {
+               return err
+       }
+
+       if err := pg.dnssec.gatherKeys(pg.db); err != nil {
+               return err
+       }
+
+       if err := pg.url.gatherKeys(pg.db); err != nil {
+               return err
+       }
+
+       if err := pg.uri.gatherKeys(pg.db); err != nil {
+               return err
+       }
+
+       return nil
+}
+
+// GetSSLKeys converts the backends internal key representation into the 
common representation (SSLKey)
+func (pg *PGBackend) GetSSLKeys() ([]SSLKey, error) {
+       if err := pg.sslKey.decrypt(pg.cfg.AESKey); err != nil {
+               return nil, err
+       }
+       return pg.sslKey.toGeneric(), nil
+}
+
+// SetSSLKeys takes in keys and converts & encrypts the data into the backends 
internal format
+func (pg *PGBackend) SetSSLKeys(keys []SSLKey) error {
+       pg.sslKey.fromGeneric(keys)
+       return pg.sslKey.encrypt(pg.cfg.AESKey)
+}
+
+// GetDNSSecKeys converts the backends internal key representation into the 
common representation (DNSSecKey)
+func (pg *PGBackend) GetDNSSecKeys() ([]DNSSecKey, error) {
+       if err := pg.dnssec.decrypt(pg.cfg.AESKey); err != nil {
+               return nil, err
+       }
+       return pg.dnssec.toGeneric(), nil
+}
+
+// SetDNSSecKeys takes in keys and converts & encrypts the data into the 
backends internal format
+func (pg *PGBackend) SetDNSSecKeys(keys []DNSSecKey) error {
+       pg.dnssec.fromGeneric(keys)
+       return pg.dnssec.encrypt(pg.cfg.AESKey)
+}
+
+// GetURISignKeys converts the pg internal key representation into the common 
representation (URISignKey)
+func (pg *PGBackend) GetURISignKeys() ([]URISignKey, error) {
+       if err := pg.uri.decrypt(pg.cfg.AESKey); err != nil {
+               return nil, err
+       }
+       return pg.uri.toGeneric(), nil
+}
+
+// SetURISignKeys takes in keys and converts & encrypts the data into the 
backends internal format
+func (pg *PGBackend) SetURISignKeys(keys []URISignKey) error {
+       pg.uri.fromGeneric(keys)
+       return pg.uri.encrypt(pg.cfg.AESKey)
+}
+
+// GetURLSigKeys converts the backends internal key representation into the 
common representation (URLSigKey)
+func (pg *PGBackend) GetURLSigKeys() ([]URLSigKey, error) {
+       if err := pg.url.decrypt(pg.cfg.AESKey); err != nil {
+               return nil, err
+       }
+       return pg.url.toGeneric(), nil
+}
+
+// SetURLSigKeys takes in keys and converts & encrypts the data into the 
backends internal format
+func (pg *PGBackend) SetURLSigKeys(keys []URLSigKey) error {
+       pg.url.fromGeneric(keys)
+       return pg.url.encrypt(pg.cfg.AESKey)
+}
+
+type pgCommonRecord struct {
+       DataEncrypted []byte
+}
+
+type pgDNSSecRecord struct {
+       Key tc.DNSSECKeysTrafficVault
+       CDN string
+       pgCommonRecord
+}
+type pgDNSSecTable struct {
+       Records []pgDNSSecRecord
+}
+
+func (tbl *pgDNSSecTable) gatherKeys(db *sql.DB) error {
+       sz, err := getSize(db, "dnssec")
+       if err != nil {
+               log.Println("PGDNSSec gatherKeys: unable to determine size of 
dnssec table")
+       }
+       tbl.Records = make([]pgDNSSecRecord, sz)
+
+       rows, err := db.Query("SELECT cdn, data from dnssec")
+       if err != nil {
+               return fmt.Errorf("PGDNSSec gatherKeys: unable to query: %w", 
err)
+       }
+       defer rows.Close()
+       i := 0
+       for rows.Next() {
+               if i > len(tbl.Records)-1 {
+                       return fmt.Errorf("PGDNSSec gatherKeys got more results 
than expected %v", len(tbl.Records))
+               }
+               err := rows.Scan(&tbl.Records[i].CDN, 
&tbl.Records[i].DataEncrypted)
+               if err != nil {
+                       return fmt.Errorf("PGDNSSec gatherKeys unable to scan 
row: %w", err)
+               }
+               i += 1
+       }
+       return nil
+}
+func (tbl *pgDNSSecTable) decrypt(aesKey []byte) error {
+       for i, _ := range tbl.Records {
+               err := decryptInto(aesKey, tbl.Records[i].DataEncrypted, 
&tbl.Records[i].Key)
+               if err != nil {
+                       return fmt.Errorf("unable to decrypt into keys: %w", 
err)
+               }
+       }
+       return nil
+}
+func (tbl *pgDNSSecTable) encrypt(aesKey []byte) error {
+       for i, dns := range tbl.Records {
+               data, err := json.Marshal(&dns.Key)
+               if err != nil {
+                       return fmt.Errorf("encrypt issue marshalling keys: %w", 
err)
+               }
+               dat, err := encrypt(data, aesKey)
+               if err != nil {
+                       return fmt.Errorf("encrypt error: %w", err)
+               }
+               tbl.Records[i].DataEncrypted = dat
+       }
+       return nil
+}
+func (tbl *pgDNSSecTable) toGeneric() []DNSSecKey {
+       keys := make([]DNSSecKey, len(tbl.Records))
+
+       for i, record := range tbl.Records {
+               keys[i] = DNSSecKey{
+                       CDN:                    record.CDN,
+                       DNSSECKeysTrafficVault: record.Key,
+               }
+       }
+
+       return keys
+}
+func (tbl *pgDNSSecTable) fromGeneric(keys []DNSSecKey) {
+       tbl.Records = make([]pgDNSSecRecord, len(keys))
+
+       for i, key := range keys {
+               tbl.Records[i] = pgDNSSecRecord{
+                       Key: key.DNSSECKeysTrafficVault,
+                       CDN: key.CDN,
+                       pgCommonRecord: pgCommonRecord{
+                               DataEncrypted: nil,
+                       },
+               }
+       }
+}
+func (tbl *pgDNSSecTable) validate() []string {
+       for i, record := range tbl.Records {
+               if record.DataEncrypted == nil && len(record.Key) > 0 {
+                       return []string{fmt.Sprintf("DNSSEC Key %v: 
DataEncrypted is blank!", i)}
+               }
+       }
+       return nil
+}
+func (tbl *pgDNSSecTable) insertKeys(db *sql.DB) error {
+       queryBase := "INSERT INTO dnssec (cdn, data) VALUES "
+       stride := 2
+       queryArgs := make([]interface{}, len(tbl.Records)*stride)
+       for i, record := range tbl.Records {
+               j := i * stride
+               queryArgs[j] = record.CDN
+               queryArgs[j+1] = record.DataEncrypted
+       }
+       return insertIntoTable(db, queryBase, stride, queryArgs)
+}
+
+type pgSSLKeyRecord struct {
+       Keys tc.DeliveryServiceSSLKeys
+       pgCommonRecord
+
+       // These records are stored on the table but are duplicated
+       DeliveryService string
+       CDN             string
+}
+type pgSSLKeyTable struct {
+       Records []pgSSLKeyRecord
+}
+
+func (tbl *pgSSLKeyTable) insertKeys(db *sql.DB) error {
+       queryBase := "INSERT INTO sslkey (deliveryservice, data, cdn, version) 
VALUES "
+       duplicateKeys := 2
+       stride := 4
+       queryArgs := make([]interface{}, len(tbl.Records)*stride*duplicateKeys)
+       for i, record := range tbl.Records {
+               j := i * duplicateKeys * stride
+
+               queryArgs[j] = record.DeliveryService
+               queryArgs[j+1] = record.DataEncrypted
+               queryArgs[j+2] = record.CDN
+               queryArgs[j+3] = record.Keys.Version.String()
+
+               queryArgs[j+4] = record.DeliveryService
+               queryArgs[j+5] = record.DataEncrypted
+               queryArgs[j+6] = record.CDN
+               queryArgs[j+7] = "latest"
+       }
+       return insertIntoTable(db, queryBase, 4, queryArgs)
+}
+func (tbl *pgSSLKeyTable) gatherKeys(db *sql.DB) error {
+       sz, err := getSize(db, "sslkey WHERE version='latest'")
+       if err != nil {
+               return fmt.Errorf("PGSSLKey gatherKeys unable to determine size 
of sslkey table: %w", err)
+       }
+       tbl.Records = make([]pgSSLKeyRecord, sz)
+
+       rows, err := db.Query("SELECT data, deliveryservice, cdn from sslkey 
WHERE version = 'latest'")
+       if err != nil {
+               return fmt.Errorf("PGSSLKey gatherKeys unable to query: %w", 
err)
+       }
+       defer rows.Close()
+       i := 0
+       for rows.Next() {
+               if i > len(tbl.Records)-1 {
+                       return errors.New("PGSSLKey gatherKeys: got more 
results than expected")
+               }
+               err := rows.Scan(&tbl.Records[i].DataEncrypted, 
&tbl.Records[i].DeliveryService, &tbl.Records[i].CDN)
+               if err != nil {
+                       return fmt.Errorf("PGSSLKey gatherKeys unable to scan 
row: %w", err)
+               }
+               i += 1
+       }
+       return nil
+}
+func (tbl *pgSSLKeyTable) decrypt(aesKey []byte) error {
+       for i, dns := range tbl.Records {
+               err := decryptInto(aesKey, dns.DataEncrypted, 
&tbl.Records[i].Keys)
+               if err != nil {
+                       return fmt.Errorf("unable to decrypt into keys: %w", 
err)
+               }
+       }
+       return nil
+}
+func (tbl *pgSSLKeyTable) encrypt(aesKey []byte) error {
+       for i, dns := range tbl.Records {
+               data, err := json.Marshal(dns.Keys)
+               if err != nil {
+                       return fmt.Errorf("encrypt issue marshalling keys: %w", 
err)
+               }
+               dat, err := encrypt(data, aesKey)
+               if err != nil {
+                       return fmt.Errorf("encrypt error: %w", err)
+               }
+               tbl.Records[i].DataEncrypted = dat
+       }
+       return nil
+}
+func (tbl *pgSSLKeyTable) toGeneric() []SSLKey {
+       keys := make([]SSLKey, len(tbl.Records))
+
+       for i, record := range tbl.Records {
+               keys[i] = SSLKey{
+                       DeliveryServiceSSLKeys: record.Keys,
+               }
+       }
+       return keys
+}
+func (tbl *pgSSLKeyTable) fromGeneric(keys []SSLKey) {
+       tbl.Records = make([]pgSSLKeyRecord, len(keys))
+
+       for i, key := range keys {
+               tbl.Records[i] = pgSSLKeyRecord{
+                       Keys: key.DeliveryServiceSSLKeys,
+                       pgCommonRecord: pgCommonRecord{
+                               DataEncrypted: nil,
+                       },
+                       DeliveryService: key.DeliveryService,
+                       CDN:             key.CDN,
+               }
+       }
+}
+func (tbl *pgSSLKeyTable) validate() []string {
+       defaultKey := tc.DeliveryServiceSSLKeys{}
+       var errors []string
+       fmtStr := "SSL Key %v: %v"
+       for i, record := range tbl.Records {
+               if record.Keys == defaultKey {
+                       errors = append(errors, fmt.Sprintf(fmtStr, i, "DS SSL 
Keys are default!"))
+               } else if record.Keys.Key == "" {
+                       errors = append(errors, fmt.Sprintf(fmtStr, i, "Key is 
blank!"))
+               } else if record.Keys.CDN == "" {
+                       errors = append(errors, fmt.Sprintf(fmtStr, i, "CDN is 
blank!"))
+               } else if record.Keys.DeliveryService == "" {
+                       errors = append(errors, fmt.Sprintf(fmtStr, i, "DS is 
blank!"))
+               } else if record.DataEncrypted == nil {
+                       errors = append(errors, fmt.Sprintf(fmtStr, i, 
"DataEncrypted is blank!"))
+               } else if record.Keys.Version.String() == "" {
+                       errors = append(errors, fmt.Sprintf(fmtStr, i, "Version 
is blank!"))
+               }
+       }
+       return errors
+}
+
+type pgURLSigKeyRecord struct {
+       Keys            tc.URLSigKeys
+       DeliveryService string
+       pgCommonRecord
+}
+type pgURLSigKeyTable struct {
+       Records []pgURLSigKeyRecord
+}
+
+func (tbl *pgURLSigKeyTable) insertKeys(db *sql.DB) error {
+       queryBase := "INSERT INTO url_sig_key (deliveryservice, data) VALUES "
+       stride := 2
+       queryArgs := make([]interface{}, len(tbl.Records)*stride)
+       for i, record := range tbl.Records {
+               j := i * stride
+               queryArgs[j] = record.DeliveryService
+               queryArgs[j+1] = record.DataEncrypted
+       }
+       return insertIntoTable(db, queryBase, stride, queryArgs)
+}
+func (tbl *pgURLSigKeyTable) gatherKeys(db *sql.DB) error {
+       sz, err := getSize(db, "url_sig_key")
+       if err != nil {
+               log.Println("PGURLSigKey gatherKeys: unable to determine 
url_sig_key table size")
+       }
+       tbl.Records = make([]pgURLSigKeyRecord, sz)
+
+       rows, err := db.Query("SELECT deliveryservice, data from url_sig_key")
+       if err != nil {
+               return fmt.Errorf("PGURLSigKey gatherKeys error creating query: 
%w", err)
+       }
+       defer rows.Close()
+       i := 0
+       for rows.Next() {
+               if i > len(tbl.Records)-1 {
+                       return fmt.Errorf("PGURLSigKey gatherKeys: got more 
results than expected %v", len(tbl.Records))
+               }
+               err := rows.Scan(&tbl.Records[i].DeliveryService, 
&tbl.Records[i].DataEncrypted)
+               if err != nil {
+                       return fmt.Errorf("PGURLSigKey gatherKeys: unable to 
scan row: %w", err)
+               }
+               i += 1
+       }
+       return nil
+}
+func (tbl *pgURLSigKeyTable) decrypt(aesKey []byte) error {
+       for i, sig := range tbl.Records {
+               err := decryptInto(aesKey, sig.DataEncrypted, 
&tbl.Records[i].Keys)
+               if err != nil {
+                       return fmt.Errorf("unable to decrypt into keys: %w", 
err)
+               }
+       }
+       return nil
+}
+func (tbl *pgURLSigKeyTable) encrypt(aesKey []byte) error {
+       for i, sig := range tbl.Records {
+               data, err := json.Marshal(&sig.Keys)
+               if err != nil {
+                       return fmt.Errorf("encrypt issue marshalling keys: %w", 
err)
+               }
+
+               dat, err := encrypt(data, aesKey)
+               if err != nil {
+                       return fmt.Errorf("encrypt error: %w", err)
+               }
+               tbl.Records[i].DataEncrypted = dat
+       }
+       return nil
+}
+func (tbl *pgURLSigKeyTable) toGeneric() []URLSigKey {
+       keys := make([]URLSigKey, len(tbl.Records))
+
+       for i, record := range tbl.Records {
+               keys[i] = URLSigKey{
+                       DeliveryService: record.DeliveryService,
+                       URLSigKeys:      record.Keys,
+               }
+       }
+       return keys
+}
+func (tbl *pgURLSigKeyTable) fromGeneric(keys []URLSigKey) {
+       tbl.Records = make([]pgURLSigKeyRecord, len(keys))
+
+       for i, key := range keys {
+               tbl.Records[i] = pgURLSigKeyRecord{
+                       Keys:            key.URLSigKeys,
+                       DeliveryService: key.DeliveryService,
+                       pgCommonRecord: pgCommonRecord{
+                               DataEncrypted: nil,
+                       },
+               }
+       }
+}
+func (tbl *pgURLSigKeyTable) validate() []string {
+       for i, record := range tbl.Records {
+               if record.DataEncrypted == nil && len(record.Keys) > 0 {
+                       return []string{fmt.Sprintf("URl Sig Key %v: 
DataEncrypted is blank!", i)}
+               }
+       }
+       return nil
+}
+
+type pgURISignKeyRecord struct {
+       Keys            map[string]tc.URISignerKeyset
+       DeliveryService string
+       pgCommonRecord
+}
+type pgURISignKeyTable struct {
+       Records []pgURISignKeyRecord
+}
+
+func (tbl *pgURISignKeyTable) insertKeys(db *sql.DB) error {
+       queryBase := "INSERT INTO uri_signing_key (deliveryservice, data) 
VALUES "

Review comment:
       Riak does indeed just overwrite. I've done this for most of the tables,
but SSL Keys is only constrained by its id, making comparison quite expensive.
Do you think it's worth checking for SSL Keys as well? Personally, I think we
should just skip it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to