shamrickus commented on a change in pull request #5924: URL: https://github.com/apache/trafficcontrol/pull/5924#discussion_r651099277
########## File path: tools/traffic_vault_migrate/postgres.go ########## @@ -0,0 +1,743 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "log" + "strconv" + "strings" + + "github.com/apache/trafficcontrol/lib/go-tc" + util "github.com/apache/trafficcontrol/lib/go-util" + + _ "github.com/lib/pq" +) + +// PGConfig represents the configuration options available to the PG backend +type PGConfig struct { + Host string `json:"host"` + Port string `json:"port"` + User string `json:"user"` + Password string `json:"password"` + SSLMode string `json:"sslmode"` + Database string `json:"database"` + Key string `json:"aesKey"` + AESKey []byte +} + +// PGBackend is the Postgres implementation of TVBackend +type PGBackend struct { + sslKey pgSSLKeyTable + dnssec pgDNSSecTable + uri pgURISignKeyTable + url pgURLSigKeyTable + cfg PGConfig + db *sql.DB +} + +// String returns a high level overview of the backend and its keys +func (pg *PGBackend) String() string { + data := fmt.Sprintf("PG server %v@%v:%v\n", pg.cfg.User, pg.cfg.Host, pg.cfg.Port) + data += fmt.Sprintf("\tSSL Keys: %v\n", len(pg.sslKey.Records)) + data += 
fmt.Sprintf("\tDNSSec Keys: %v\n", len(pg.dnssec.Records)) + data += fmt.Sprintf("\tURI Keys: %v\n", len(pg.uri.Records)) + data += fmt.Sprintf("\tURL Keys: %v\n", len(pg.url.Records)) + return data +} + +// Name returns the name for this backend +func (pg *PGBackend) Name() string { + return "PG" +} + +// ReadConfig takes in a filename and will read it into the backends config +func (pg *PGBackend) ReadConfig(s string) error { + err := UnmarshalConfig(s, &pg.cfg) + if err != nil { + return err + } + + pg.cfg.AESKey, err = base64.StdEncoding.DecodeString(pg.cfg.Key) + if err != nil { + return fmt.Errorf("unable to decode PG AESKey: %w", err) + } + return nil +} + +// Insert takes the current keys and inserts them into the backend DB +func (pg *PGBackend) Insert() error { + if err := pg.sslKey.insertKeys(pg.db); err != nil { + return err + } + if err := pg.dnssec.insertKeys(pg.db); err != nil { + return err + } + if err := pg.url.insertKeys(pg.db); err != nil { + return err + } + if err := pg.uri.insertKeys(pg.db); err != nil { + return err + } + return nil +} + +// Start initiates the connection to the backend DB +func (pg *PGBackend) Start() error { + sqlStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", pg.cfg.User, pg.cfg.Password, pg.cfg.Host, pg.cfg.Port, pg.cfg.Database, pg.cfg.SSLMode) + db, err := sql.Open("postgres", sqlStr) + if err != nil { + return fmt.Errorf("unable to start PG client: %w", err) + } + + pg.db = db + pg.sslKey = pgSSLKeyTable{} + pg.dnssec = pgDNSSecTable{} + pg.url = pgURLSigKeyTable{} + pg.uri = pgURISignKeyTable{} + + return nil +} + +// ValidateKey validates that the keys are valid (in most cases, certain fields are not null) +func (pg *PGBackend) ValidateKey() []string { + var errors []string + if errs := pg.sslKey.validate(); errs != nil { + errors = append(errors, errs...) + } + if errs := pg.dnssec.validate(); errs != nil { + errors = append(errors, errs...) 
+ } + if errs := pg.uri.validate(); errs != nil { + errors = append(errors, errs...) + } + if errs := pg.url.validate(); errs != nil { + errors = append(errors, errs...) + } + return errors +} + +// Close terminates the connection to the backend DB +func (pg *PGBackend) Close() error { + return pg.db.Close() +} + +// Ping checks the connection to the backend DB +func (pg *PGBackend) Ping() error { + return pg.db.Ping() +} + +// Fetch gets all of the keys from the backend DB +func (pg *PGBackend) Fetch() error { + if err := pg.sslKey.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.dnssec.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.url.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.uri.gatherKeys(pg.db); err != nil { + return err + } + + return nil +} + +// GetSSLKeys converts the backends internal key representation into the common representation (SSLKey) +func (pg *PGBackend) GetSSLKeys() ([]SSLKey, error) { + if err := pg.sslKey.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.sslKey.toGeneric(), nil +} + +// SetSSLKeys takes in keys and converts & encrypts the data into the backends internal format +func (pg *PGBackend) SetSSLKeys(keys []SSLKey) error { + pg.sslKey.fromGeneric(keys) + return pg.sslKey.encrypt(pg.cfg.AESKey) +} + +// GetDNSSecKeys converts the backends internal key representation into the common representation (DNSSecKey) +func (pg *PGBackend) GetDNSSecKeys() ([]DNSSecKey, error) { + if err := pg.dnssec.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.dnssec.toGeneric(), nil +} + +// SetDNSSecKeys takes in keys and converts & encrypts the data into the backends internal format +func (pg *PGBackend) SetDNSSecKeys(keys []DNSSecKey) error { + pg.dnssec.fromGeneric(keys) + return pg.dnssec.encrypt(pg.cfg.AESKey) +} + +// GetURISignKeys converts the pg internal key representation into the common representation (URISignKey) +func (pg *PGBackend) 
GetURISignKeys() ([]URISignKey, error) { + if err := pg.uri.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.uri.toGeneric(), nil +} + +// SetURISignKeys takes in keys and converts & encrypts the data into the backends internal format +func (pg *PGBackend) SetURISignKeys(keys []URISignKey) error { + pg.uri.fromGeneric(keys) + return pg.uri.encrypt(pg.cfg.AESKey) +} + +// GetURLSigKeys converts the backends internal key representation into the common representation (URLSigKey) +func (pg *PGBackend) GetURLSigKeys() ([]URLSigKey, error) { + if err := pg.url.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.url.toGeneric(), nil +} + +// SetURLSigKeys takes in keys and converts & encrypts the data into the backends internal format +func (pg *PGBackend) SetURLSigKeys(keys []URLSigKey) error { + pg.url.fromGeneric(keys) + return pg.url.encrypt(pg.cfg.AESKey) +} + +type pgCommonRecord struct { + DataEncrypted []byte +} + +type pgDNSSecRecord struct { + Key tc.DNSSECKeysTrafficVault + CDN string + pgCommonRecord +} +type pgDNSSecTable struct { + Records []pgDNSSecRecord +} + +func (tbl *pgDNSSecTable) gatherKeys(db *sql.DB) error { + sz, err := getSize(db, "dnssec") + if err != nil { + log.Println("PGDNSSec gatherKeys: unable to determine size of dnssec table") + } + tbl.Records = make([]pgDNSSecRecord, sz) + + rows, err := db.Query("SELECT cdn, data from dnssec") + if err != nil { + return fmt.Errorf("PGDNSSec gatherKeys: unable to query: %w", err) + } + defer rows.Close() + i := 0 + for rows.Next() { + if i > len(tbl.Records)-1 { + return fmt.Errorf("PGDNSSec gatherKeys got more results than expected %v", len(tbl.Records)) + } + err := rows.Scan(&tbl.Records[i].CDN, &tbl.Records[i].DataEncrypted) + if err != nil { + return fmt.Errorf("PGDNSSec gatherKeys unable to scan row: %w", err) + } + i += 1 + } + return nil +} +func (tbl *pgDNSSecTable) decrypt(aesKey []byte) error { + for i, _ := range tbl.Records { + err := 
decryptInto(aesKey, tbl.Records[i].DataEncrypted, &tbl.Records[i].Key) + if err != nil { + return fmt.Errorf("unable to decrypt into keys: %w", err) + } + } + return nil +} +func (tbl *pgDNSSecTable) encrypt(aesKey []byte) error { + for i, dns := range tbl.Records { + data, err := json.Marshal(&dns.Key) + if err != nil { + return fmt.Errorf("encrypt issue marshalling keys: %w", err) + } + dat, err := encrypt(data, aesKey) + if err != nil { + return fmt.Errorf("encrypt error: %w", err) + } + tbl.Records[i].DataEncrypted = dat + } + return nil +} +func (tbl *pgDNSSecTable) toGeneric() []DNSSecKey { + keys := make([]DNSSecKey, len(tbl.Records)) + + for i, record := range tbl.Records { + keys[i] = DNSSecKey{ + CDN: record.CDN, + DNSSECKeysTrafficVault: record.Key, + } + } + + return keys +} +func (tbl *pgDNSSecTable) fromGeneric(keys []DNSSecKey) { + tbl.Records = make([]pgDNSSecRecord, len(keys)) + + for i, key := range keys { + tbl.Records[i] = pgDNSSecRecord{ + Key: key.DNSSECKeysTrafficVault, + CDN: key.CDN, + pgCommonRecord: pgCommonRecord{ + DataEncrypted: nil, + }, + } + } +} +func (tbl *pgDNSSecTable) validate() []string { + for i, record := range tbl.Records { + if record.DataEncrypted == nil && len(record.Key) > 0 { + return []string{fmt.Sprintf("DNSSEC Key %v: DataEncrypted is blank!", i)} + } + } + return nil +} +func (tbl *pgDNSSecTable) insertKeys(db *sql.DB) error { + queryBase := "INSERT INTO dnssec (cdn, data) VALUES " + stride := 2 + queryArgs := make([]interface{}, len(tbl.Records)*stride) + for i, record := range tbl.Records { + j := i * stride + queryArgs[j] = record.CDN + queryArgs[j+1] = record.DataEncrypted + } + return insertIntoTable(db, queryBase, stride, queryArgs) +} + +type pgSSLKeyRecord struct { + Keys tc.DeliveryServiceSSLKeys + pgCommonRecord + + // These records are stored on the table but are duplicated + DeliveryService string + CDN string +} +type pgSSLKeyTable struct { + Records []pgSSLKeyRecord +} + +func (tbl *pgSSLKeyTable) 
insertKeys(db *sql.DB) error { + queryBase := "INSERT INTO sslkey (deliveryservice, data, cdn, version) VALUES " + duplicateKeys := 2 + stride := 4 + queryArgs := make([]interface{}, len(tbl.Records)*stride*duplicateKeys) + for i, record := range tbl.Records { + j := i * duplicateKeys * stride + + queryArgs[j] = record.DeliveryService + queryArgs[j+1] = record.DataEncrypted + queryArgs[j+2] = record.CDN + queryArgs[j+3] = record.Keys.Version.String() + + queryArgs[j+4] = record.DeliveryService + queryArgs[j+5] = record.DataEncrypted + queryArgs[j+6] = record.CDN + queryArgs[j+7] = "latest" + } + return insertIntoTable(db, queryBase, 4, queryArgs) +} +func (tbl *pgSSLKeyTable) gatherKeys(db *sql.DB) error { + sz, err := getSize(db, "sslkey WHERE version='latest'") Review comment: Hmm, I was misreading how we insert SSL keys; you are correct. I'll remove the `version='latest'` filter. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
