rawlinp commented on a change in pull request #5924: URL: https://github.com/apache/trafficcontrol/pull/5924#discussion_r653894403
########## File path: traffic_ops/app/db/traffic_vault_migrate/postgres.go ########## @@ -0,0 +1,744 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/apache/trafficcontrol/lib/go-log" + "github.com/apache/trafficcontrol/lib/go-tc" + util "github.com/apache/trafficcontrol/lib/go-util" + + _ "github.com/lib/pq" +) + +// PGConfig represents the configuration options available to the PG backend. +type PGConfig struct { + Host string `json:"host"` + Port string `json:"port"` + User string `json:"user"` + Password string `json:"password"` + SSLMode string `json:"sslmode"` + Database string `json:"database"` + Key string `json:"aesKey"` Review comment: nit: maybe this should be named `KeyBase64` to differentiate it more from `AESKey` ########## File path: traffic_ops/app/db/traffic_vault_migrate/traffic_vault_migrate.go ########## @@ -0,0 +1,611 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + stdlog "log" + "os" + "reflect" + "sort" + "strings" + + "github.com/pborman/getopt/v2" + + "github.com/apache/trafficcontrol/lib/go-log" + "github.com/apache/trafficcontrol/lib/go-tc" +) + +var ( + fromType string + toType string + fromCfgPath string + toCfgPath string + logCfgPath string + keyFile string + dry bool + compare bool + noConfirm bool + dump bool + logLevel string + + cfg config = config{ + LogLocationError: log.LogLocationStderr, + LogLocationWarning: log.LogLocationStdout, + LogLocationInfo: log.LogLocationStdout, + LogLocationDebug: log.LogLocationNull, + LogLocationEvent: log.LogLocationNull, + } + riakBE RiakBackend = RiakBackend{} + pgBE PGBackend = PGBackend{} +) + +func init() { + fromTypePtr := getopt.StringLong("fromType", 't', riakBE.Name(), fmt.Sprintf("From server types (%v)", strings.Join(supportedTypes(), "|"))) + if fromTypePtr == nil { + stdlog.Fatal("unable to load fromType") + } + fromType = *fromTypePtr + + toTypePtr := getopt.StringLong("toType", 'o', pgBE.Name(), fmt.Sprintf("From server types (%v)", strings.Join(supportedTypes(), "|"))) + if toTypePtr == nil { + stdlog.Fatal("unable to load toType") + } + toType = *toTypePtr + + toCfgPtr := 
getopt.StringLong("toCfgPath", 'g', "pg.json", "To server config file") + if toCfgPtr == nil { + stdlog.Fatal("unable to load toCfg") + } + toCfgPath = *toCfgPtr + + fromCfgPtr := getopt.StringLong("fromCfgPath", 'f', "riak.json", "From server config file") + if fromCfgPtr == nil { + stdlog.Fatal("unable to load fromCfg") + } + fromCfgPath = *fromCfgPtr + + getopt.FlagLong(&dry, "dry", 'r', "Do not perform writes"). + SetOptional(). + SetFlag(). + SetGroup("no_insert") + + getopt.FlagLong(&compare, "compare", 'c', "Compare to and from server records"). + SetOptional(). + SetFlag(). + SetGroup("no_insert") + + getopt.FlagLong(&noConfirm, "noConfirm", 'm', "Requires confirmation before inserting records"). + SetFlag() + + getopt.FlagLong(&dump, "dump", 'd', "Write keys (from 'from' server) to disk"). + SetOptional(). + SetGroup("disk_bck"). + SetFlag() + + getopt.FlagLong(&keyFile, "fill", 'i', "Insert data into `to` server with data this directory"). + SetOptional(). + SetGroup("disk_bck") + + getopt.FlagLong(&logCfgPath, "logCfg", 'l', "Log configuration file"). + SetOptional(). + SetGroup("log") + + getopt.FlagLong(&logLevel, "logLevel", 'e', "Print everything at above specified log level (error|warning|info|debug|event)"). + SetOptional(). + SetGroup("log") +} + +// supportBackends returns the backends available in this tool. 
+func supportedBackends() []TVBackend { + return []TVBackend{ + &riakBE, &pgBE, + } +} + +func main() { + getopt.ParseV2() + + initConfig() + + var fromSrv TVBackend + var toSrv TVBackend + + importData := keyFile != "" + toSrvUsed := !dump && !dry || keyFile != "" + + if !importData { + log.Infof("Initiating fromSrv %s...\n", fromType) + if !validateType(fromType) { + log.Errorln("Unknown fromType " + fromType) + os.Exit(1) + } + fromSrv = getBackendFromType(fromType) + if err := fromSrv.ReadConfigFile(fromCfgPath); err != nil { + log.Errorf("Unable to read fromSrv cfg: %v", err) + os.Exit(1) + } + + if err := fromSrv.Start(); err != nil { + log.Errorf("issue starting fromSrv: %v", err) + os.Exit(1) + } + defer log.Close(fromSrv, "closing fromSrv") + + if err := fromSrv.Ping(); err != nil { + log.Errorf("Unable to ping fromSrv: %v", err) + os.Exit(1) + } + } + + if toSrvUsed { + log.Infof("Initiating toSrv %s...\n", toType) + if !validateType(toType) { + log.Errorln("Unknown toType " + toType) + os.Exit(1) + } + toSrv = getBackendFromType(toType) + + if err := toSrv.ReadConfigFile(toCfgPath); err != nil { + log.Errorf("Unable to read toSrv cfg: %v", err) + os.Exit(1) + } + + if err := toSrv.Start(); err != nil { + log.Errorf("issue starting toSrv: %v", err) + os.Exit(1) + } + defer log.Close(toSrv, "closing toSrv") + + if err := toSrv.Ping(); err != nil { + log.Errorf("Unable to ping toSrv: %v", err) + os.Exit(1) + } + } + + var fromSecret Secrets + if !importData { + var err error + log.Infof("Fetching data from %s...\n", fromSrv.Name()) + if err = fromSrv.Fetch(); err != nil { + log.Errorf("Unable to fetch fromSrv data: %v", err) + os.Exit(1) + } + + if fromSecret, err = GetKeys(fromSrv); err != nil { + log.Errorln(err) + os.Exit(1) + } + + if err := Validate(fromSrv); err != nil { + log.Errorln(err) + os.Exit(1) + } + + } else { + err := fromSecret.fill(keyFile) + if err != nil { + log.Errorln(err) + os.Exit(1) + } + } + + if dump { + log.Infof("Dumping data 
from %s...\n", fromSrv.Name()) + fromSecret.dump("dump") + return + } + + if compare { + log.Infof("Fetching data from %s...\n", toSrv.Name()) + if err := toSrv.Fetch(); err != nil { + log.Errorf("Unable to fetch toSrv data: %v\n", err) + os.Exit(1) + } + + toSecret, err := GetKeys(toSrv) + if err != nil { + log.Errorln(err) + os.Exit(1) + } + log.Infoln("Validating " + toSrv.Name()) + if err := toSrv.ValidateKey(); err != nil && len(err) > 0 { + log.Errorln(strings.Join(err, "\n")) + os.Exit(1) + } + + fromSecret.sort() + toSecret.sort() + + if !importData { + log.Infoln(fromSrv.String()) + } else { + log.Infof("Disk backup:\n\tSSL Keys: %d\n\tDNSSec Keys: %d\n\tURI Keys: %d\n\tURL Keys: %d\n", len(fromSecret.sslkeys), len(fromSecret.dnssecKeys), len(fromSecret.uriKeys), len(fromSecret.urlKeys)) + } + log.Infoln(toSrv.String()) + + if !reflect.DeepEqual(fromSecret.sslkeys, toSecret.sslkeys) { Review comment: I just want to say this is awesome -- I'm glad you thought of adding comparison functionality 😄 . This will definitely help give us the warm fuzzies. ########## File path: traffic_ops/app/db/traffic_vault_migrate/traffic_vault_migrate.go ########## @@ -0,0 +1,611 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + stdlog "log" + "os" + "reflect" + "sort" + "strings" + + "github.com/pborman/getopt/v2" + + "github.com/apache/trafficcontrol/lib/go-log" + "github.com/apache/trafficcontrol/lib/go-tc" +) + +var ( + fromType string + toType string + fromCfgPath string + toCfgPath string + logCfgPath string + keyFile string + dry bool + compare bool + noConfirm bool + dump bool + logLevel string + + cfg config = config{ + LogLocationError: log.LogLocationStderr, + LogLocationWarning: log.LogLocationStdout, + LogLocationInfo: log.LogLocationStdout, + LogLocationDebug: log.LogLocationNull, + LogLocationEvent: log.LogLocationNull, + } + riakBE RiakBackend = RiakBackend{} + pgBE PGBackend = PGBackend{} +) + +func init() { + fromTypePtr := getopt.StringLong("fromType", 't', riakBE.Name(), fmt.Sprintf("From server types (%v)", strings.Join(supportedTypes(), "|"))) + if fromTypePtr == nil { + stdlog.Fatal("unable to load fromType") + } + fromType = *fromTypePtr + + toTypePtr := getopt.StringLong("toType", 'o', pgBE.Name(), fmt.Sprintf("From server types (%v)", strings.Join(supportedTypes(), "|"))) + if toTypePtr == nil { + stdlog.Fatal("unable to load toType") + } + toType = *toTypePtr + + toCfgPtr := getopt.StringLong("toCfgPath", 'g', "pg.json", "To server config file") + if toCfgPtr == nil { + stdlog.Fatal("unable to load toCfg") + } + toCfgPath = *toCfgPtr + + fromCfgPtr := getopt.StringLong("fromCfgPath", 'f', "riak.json", "From server config file") + if fromCfgPtr == nil { + stdlog.Fatal("unable to load fromCfg") + } + fromCfgPath = *fromCfgPtr + + getopt.FlagLong(&dry, "dry", 'r', "Do not perform writes"). + SetOptional(). + SetFlag(). + SetGroup("no_insert") + + getopt.FlagLong(&compare, "compare", 'c', "Compare to and from server records"). + SetOptional(). + SetFlag(). 
+ SetGroup("no_insert") + + getopt.FlagLong(&noConfirm, "noConfirm", 'm', "Requires confirmation before inserting records"). + SetFlag() + + getopt.FlagLong(&dump, "dump", 'd', "Write keys (from 'from' server) to disk"). + SetOptional(). + SetGroup("disk_bck"). + SetFlag() + + getopt.FlagLong(&keyFile, "fill", 'i', "Insert data into `to` server with data this directory"). + SetOptional(). + SetGroup("disk_bck") + + getopt.FlagLong(&logCfgPath, "logCfg", 'l', "Log configuration file"). + SetOptional(). + SetGroup("log") + + getopt.FlagLong(&logLevel, "logLevel", 'e', "Print everything at above specified log level (error|warning|info|debug|event)"). + SetOptional(). + SetGroup("log") +} + +// supportBackends returns the backends available in this tool. +func supportedBackends() []TVBackend { + return []TVBackend{ + &riakBE, &pgBE, + } +} + +func main() { + getopt.ParseV2() + + initConfig() + + var fromSrv TVBackend + var toSrv TVBackend + + importData := keyFile != "" + toSrvUsed := !dump && !dry || keyFile != "" Review comment: nit: `keyFile != ""` can be replaced with `importData`, and this would be more readable with explicit parentheses: `(!dump && !dry) || importData`. That grouping matches Go's operator precedence (`&&` binds tighter than `||`), so the behavior would not change. ########## File path: traffic_ops/app/db/traffic_vault_migrate/postgres.go ########## @@ -0,0 +1,744 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/apache/trafficcontrol/lib/go-log" + "github.com/apache/trafficcontrol/lib/go-tc" + util "github.com/apache/trafficcontrol/lib/go-util" + + _ "github.com/lib/pq" +) + +// PGConfig represents the configuration options available to the PG backend. +type PGConfig struct { + Host string `json:"host"` + Port string `json:"port"` + User string `json:"user"` + Password string `json:"password"` + SSLMode string `json:"sslmode"` + Database string `json:"database"` + Key string `json:"aesKey"` + AESKey []byte +} + +// PGBackend is the Postgres implementation of TVBackend. +type PGBackend struct { + sslKey pgSSLKeyTable + dnssec pgDNSSecTable + uriSigningKeys pgURISignKeyTable + urlSigKeys pgURLSigKeyTable + cfg PGConfig + db *sql.DB +} + +// String returns a high level overview of the backend and its keys. +func (pg *PGBackend) String() string { + data := fmt.Sprintf("PG server %s@%s:%s\n", pg.cfg.User, pg.cfg.Host, pg.cfg.Port) + data += fmt.Sprintf("\tSSL Keys: %d\n", len(pg.sslKey.Records)) + data += fmt.Sprintf("\tDNSSec Keys: %d\n", len(pg.dnssec.Records)) + data += fmt.Sprintf("\tURI Signing Keys: %d\n", len(pg.uriSigningKeys.Records)) + data += fmt.Sprintf("\tURL Sig Keys: %d\n", len(pg.urlSigKeys.Records)) + return data +} + +// Name returns the name for this backend. +func (pg *PGBackend) Name() string { + return "PG" +} + +// ReadConfigFile takes in a filename and will read it into the backends config. 
+func (pg *PGBackend) ReadConfigFile(configFile string) error { + var err error + if err = UnmarshalConfig(configFile, &pg.cfg); err != nil { + return err + } + + if pg.cfg.AESKey, err = base64.StdEncoding.DecodeString(pg.cfg.Key); err != nil { + return fmt.Errorf("unable to decode PG AESKey '%s': %w", pg.cfg.Key, err) + } + + if err = util.ValidateAESKey(pg.cfg.AESKey); err != nil { + return fmt.Errorf("unable to validate PG AESKey '%s'", pg.cfg.Key) + } + return nil +} + +// Insert takes the current keys and inserts them into the backend DB. +func (pg *PGBackend) Insert() error { + if err := pg.sslKey.insertKeys(pg.db); err != nil { + return err + } + if err := pg.dnssec.insertKeys(pg.db); err != nil { + return err + } + if err := pg.urlSigKeys.insertKeys(pg.db); err != nil { + return err + } + if err := pg.uriSigningKeys.insertKeys(pg.db); err != nil { + return err + } + return nil +} + +// Start initiates the connection to the backend DB. +func (pg *PGBackend) Start() error { + sqlStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", pg.cfg.User, pg.cfg.Password, pg.cfg.Host, pg.cfg.Port, pg.cfg.Database, pg.cfg.SSLMode) + db, err := sql.Open("postgres", sqlStr) + if err != nil { + sqlStr = strings.Replace(sqlStr, pg.cfg.Password, "*", 1) + return fmt.Errorf("unable to start PG client with connection string '%s': %w", sqlStr, err) + } + + pg.db = db + pg.sslKey = pgSSLKeyTable{} + pg.dnssec = pgDNSSecTable{} + pg.urlSigKeys = pgURLSigKeyTable{} + pg.uriSigningKeys = pgURISignKeyTable{} + + return nil +} + +// ValidateKey validates that the keys are valid (in most cases, certain fields are not null). +func (pg *PGBackend) ValidateKey() []string { + var allErrs []string + if errs := pg.sslKey.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := pg.dnssec.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := pg.uriSigningKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) 
+ } + if errs := pg.urlSigKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + return allErrs +} + +// Close terminates the connection to the backend DB. +func (pg *PGBackend) Close() error { + return pg.db.Close() +} + +// Ping checks the connection to the backend DB. +func (pg *PGBackend) Ping() error { + return pg.db.Ping() +} + +// Fetch gets all of the keys from the backend DB. +func (pg *PGBackend) Fetch() error { + if err := pg.sslKey.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.dnssec.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.urlSigKeys.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.uriSigningKeys.gatherKeys(pg.db); err != nil { + return err + } + + return nil +} + +// GetSSLKeys converts the backends internal key representation into the common representation (SSLKey). +func (pg *PGBackend) GetSSLKeys() ([]SSLKey, error) { + if err := pg.sslKey.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.sslKey.toGeneric(), nil +} + +// SetSSLKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetSSLKeys(keys []SSLKey) error { + pg.sslKey.fromGeneric(keys) + return pg.sslKey.encrypt(pg.cfg.AESKey) +} + +// GetDNSSecKeys converts the backends internal key representation into the common representation (DNSSecKey). +func (pg *PGBackend) GetDNSSecKeys() ([]DNSSecKey, error) { + if err := pg.dnssec.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.dnssec.toGeneric(), nil +} + +// SetDNSSecKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetDNSSecKeys(keys []DNSSecKey) error { + pg.dnssec.fromGeneric(keys) + return pg.dnssec.encrypt(pg.cfg.AESKey) +} + +// GetURISignKeys converts the pg internal key representation into the common representation (URISignKey). 
+func (pg *PGBackend) GetURISignKeys() ([]URISignKey, error) { + if err := pg.uriSigningKeys.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.uriSigningKeys.toGeneric(), nil +} + +// SetURISignKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetURISignKeys(keys []URISignKey) error { + pg.uriSigningKeys.fromGeneric(keys) + return pg.uriSigningKeys.encrypt(pg.cfg.AESKey) +} + +// GetURLSigKeys converts the backends internal key representation into the common representation (URLSigKey). +func (pg *PGBackend) GetURLSigKeys() ([]URLSigKey, error) { + if err := pg.urlSigKeys.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.urlSigKeys.toGeneric(), nil +} + +// SetURLSigKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetURLSigKeys(keys []URLSigKey) error { + pg.urlSigKeys.fromGeneric(keys) + return pg.urlSigKeys.encrypt(pg.cfg.AESKey) +} + +type pgCommonRecord struct { + DataEncrypted []byte +} + +type pgDNSSecRecord struct { + Key tc.DNSSECKeysTrafficVault + CDN string + pgCommonRecord +} +type pgDNSSecTable struct { + Records []pgDNSSecRecord +} + +func (tbl *pgDNSSecTable) gatherKeys(db *sql.DB) error { + sz, err := getSize(db, "dnssec") + if err != nil { + log.Errorln("PGDNSSec gatherKeys: unable to determine size of dnssec table") + } + tbl.Records = make([]pgDNSSecRecord, sz) + + query := "SELECT cdn, data from dnssec" + rows, err := db.Query(query) + if err != nil { + return fmt.Errorf("PGDNSSec gatherKeys: unable to run query '%s': %w", query, err) + } + defer log.Close(rows, "closing dnssec query") + i := 0 + for rows.Next() { + if i > len(tbl.Records)-1 { + return fmt.Errorf("PGDNSSec gatherKeys got more results than expected %d", len(tbl.Records)) + } + if err := rows.Scan(&tbl.Records[i].CDN, &tbl.Records[i].DataEncrypted); err != nil { + return fmt.Errorf("PGDNSSec gatherKeys unable to scan row: 
%w", err) + } + i += 1 + } + return nil +} +func (tbl *pgDNSSecTable) decrypt(aesKey []byte) error { + for i, _ := range tbl.Records { + if err := decryptInto(aesKey, tbl.Records[i].DataEncrypted, &tbl.Records[i].Key); err != nil { + return fmt.Errorf("unable to decrypt into keys: %w", err) + } + } + return nil +} +func (tbl *pgDNSSecTable) encrypt(aesKey []byte) error { + for i, dns := range tbl.Records { + data, err := json.Marshal(&dns.Key) + if err != nil { + return fmt.Errorf("encrypt issue marshalling keys: %w", err) + } + dat, err := encrypt(data, aesKey) + if err != nil { + return fmt.Errorf("encrypt error: %w", err) + } + tbl.Records[i].DataEncrypted = dat + } + return nil +} +func (tbl *pgDNSSecTable) toGeneric() []DNSSecKey { + keys := make([]DNSSecKey, len(tbl.Records)) + + for i, record := range tbl.Records { + keys[i] = DNSSecKey{ + CDN: record.CDN, + DNSSECKeysTrafficVault: record.Key, + } + } + + return keys +} +func (tbl *pgDNSSecTable) fromGeneric(keys []DNSSecKey) { + tbl.Records = make([]pgDNSSecRecord, len(keys)) + + for i, key := range keys { + tbl.Records[i] = pgDNSSecRecord{ + Key: key.DNSSECKeysTrafficVault, + CDN: key.CDN, + pgCommonRecord: pgCommonRecord{ + DataEncrypted: nil, + }, + } + } +} +func (tbl *pgDNSSecTable) validate() []string { + for _, record := range tbl.Records { + if record.DataEncrypted == nil && len(record.Key) > 0 { + return []string{fmt.Sprintf("DNSSEC Key CDN '%s': DataEncrypted is blank!", record.CDN)} + } + } + return nil +} +func (tbl *pgDNSSecTable) insertKeys(db *sql.DB) error { + queryBase := "INSERT INTO dnssec (cdn, data) VALUES %s ON CONFLICT (cdn) DO UPDATE SET data = EXCLUDED.data" + stride := 2 + queryArgs := make([]interface{}, len(tbl.Records)*stride) + for i, record := range tbl.Records { + j := i * stride + queryArgs[j] = record.CDN + queryArgs[j+1] = record.DataEncrypted + } + return insertIntoTable(db, queryBase, stride, queryArgs) +} + +type pgSSLKeyRecord struct { + Keys tc.DeliveryServiceSSLKeys 
+ pgCommonRecord + + // These records are stored on the table but are duplicated + DeliveryService string + CDN string + Version string +} +type pgSSLKeyTable struct { + Records []pgSSLKeyRecord +} + +func (tbl *pgSSLKeyTable) insertKeys(db *sql.DB) error { + queryBase := "INSERT INTO sslkey (deliveryservice, data, cdn, version) VALUES %s ON CONFLICT (deliveryservice,cdn,version) DO UPDATE SET data = EXCLUDED.data" + stride := 4 + queryArgs := make([]interface{}, len(tbl.Records)*stride) + for i, record := range tbl.Records { + j := i * stride + + queryArgs[j] = record.DeliveryService + queryArgs[j+1] = record.DataEncrypted + queryArgs[j+2] = record.CDN + queryArgs[j+3] = record.Version + } + return insertIntoTable(db, queryBase, 4, queryArgs) +} +func (tbl *pgSSLKeyTable) gatherKeys(db *sql.DB) error { + sz, err := getSize(db, "sslkey") + if err != nil { + return fmt.Errorf("PGSSLKey gatherKeys unable to determine size of sslkey table: %w", err) + } + tbl.Records = make([]pgSSLKeyRecord, sz) + + query := "SELECT data, deliveryservice, cdn, version from sslkey" + rows, err := db.Query(query) + if err != nil { + return fmt.Errorf("PGSSLKey gatherKeys unable to run query '%s': %w", query, err) + } + defer log.Close(rows, "closing sslkey query") + i := 0 + for rows.Next() { + if i > len(tbl.Records)-1 { + return fmt.Errorf("PGSSLKey gatherKeys: got more results than expected") + } + if err := rows.Scan(&tbl.Records[i].DataEncrypted, &tbl.Records[i].DeliveryService, &tbl.Records[i].CDN, &tbl.Records[i].Version); err != nil { + return fmt.Errorf("PGSSLKey gatherKeys unable to scan %d row: %w", i, err) + } + i += 1 + } + return nil +} +func (tbl *pgSSLKeyTable) decrypt(aesKey []byte) error { + for i, dns := range tbl.Records { Review comment: nit: `dns`? 
Presumably a copy-paste remnant from `pgDNSSecTable.encrypt`, which uses the same `for i, dns := range tbl.Records` loop. Since this loop iterates over SSL key records, consider renaming the variable (e.g. to `record`). ########## File path: traffic_ops/app/db/traffic_vault_migrate/riak.go ########## @@ -0,0 +1,749 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/basho/riak-go-client" + + "github.com/apache/trafficcontrol/lib/go-rfc" + "github.com/apache/trafficcontrol/lib/go-tc" +) + +const ( + BUCKET_SSL = "ssl" + BUCKET_DNSSEC = "dnssec" + BUCKET_URL_SIG = "url_sig_keys" + BUCKET_URI_SIG = "cdn_uri_sig_keys" + + INDEX_SSL = "sslkeys" + + SCHEMA_RIAK_KEY = "_yz_rk" + SCHEMA_RIAK_BUCKET = "_yz_rb" +) + +var ( + SCHEMA_SSL_FIELDS = [...]string{SCHEMA_RIAK_KEY, SCHEMA_RIAK_BUCKET} +) + +// RiakConfig represents the configuration options available to the Riak backend. +type RiakConfig struct { + Host string `json:"host"` + Port string `json:"port"` + User string `json:"user"` + Password string `json:"password"` + Insecure bool `json:"insecure"` + TLSVersionRaw string `json:"tlsVersion"` + // Timeout is the number of seconds each command should use. 
+ Timeout int `json:"timeout"` + + TLSVersion uint16 `json:"-"` +} + +// RiakBackend is the Riak implementation of TVBackend.
+type RiakBackend struct { + sslKeys riakSSLKeyTable + dnssecKeys riakDNSSecKeyTable + uriSigningKeys riakURISignKeyTable + urlSigKeys riakURLSigKeyTable + cfg RiakConfig + cluster *riak.Cluster +} + +// String returns a high level overview of the backend and its keys. +func (rb *RiakBackend) String() string { + data := fmt.Sprintf("Riak server %s@%s:%s\n", rb.cfg.User, rb.cfg.Host, rb.cfg.Port) + data += fmt.Sprintf("\tSSL Keys: %d\n", len(rb.sslKeys.Records)) + data += fmt.Sprintf("\tDNSSec Keys: %d\n", len(rb.dnssecKeys.Records)) + data += fmt.Sprintf("\tURI Signing Keys: %d\n", len(rb.uriSigningKeys.Records)) + data += fmt.Sprintf("\tURL Sig Keys: %d\n", len(rb.urlSigKeys.Records)) + return data +} + +// Name returns the name for this backend. +func (rb *RiakBackend) Name() string { + return "Riak" +} + +// ReadConfigFile takes in a filename and will read it into the backends config. +func (rb *RiakBackend) ReadConfigFile(configFile string) error { + err := UnmarshalConfig(configFile, &rb.cfg) + if err != nil { + return err + } + + switch rb.cfg.TLSVersionRaw { + case "10": + rb.cfg.TLSVersion = tls.VersionTLS10 + case "11": + rb.cfg.TLSVersion = tls.VersionTLS11 + case "12": + rb.cfg.TLSVersion = tls.VersionTLS12 + case "13": + rb.cfg.TLSVersion = tls.VersionTLS13 + default: + return fmt.Errorf("unknown tls version " + rb.cfg.TLSVersionRaw) + } + return nil +} + +// Insert takes the current keys and inserts them into the backend DB. 
+func (rb *RiakBackend) Insert() error { + if err := rb.sslKeys.insertKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + if err := rb.dnssecKeys.insertKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + if err := rb.urlSigKeys.insertKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + if err := rb.uriSigningKeys.insertKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + return nil +} + +// ValidateKey validates that the keys are valid (in most cases, certain fields are not null). +func (rb *RiakBackend) ValidateKey() []string { + allErrs := []string{} + if errs := rb.sslKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := rb.dnssecKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := rb.uriSigningKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := rb.urlSigKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + + return allErrs +} + +// SetSSLKeys takes in keys and converts & encrypts the data into the backends internal format. +func (rb *RiakBackend) SetSSLKeys(keys []SSLKey) error { + rb.sslKeys.fromGeneric(keys) + return nil +} + +// SetDNSSecKeys takes in keys and converts & encrypts the data into the backends internal format. +func (rb *RiakBackend) SetDNSSecKeys(keys []DNSSecKey) error { + rb.dnssecKeys.fromGeneric(keys) + return nil +} + +// SetURISignKeys takes in keys and converts & encrypts the data into the backends internal format. +func (rb *RiakBackend) SetURISignKeys(keys []URISignKey) error { + rb.uriSigningKeys.fromGeneric(keys) + return nil +} + +// SetURLSigKeys takes in keys and converts & encrypts the data into the backends internal format. +func (rb *RiakBackend) SetURLSigKeys(keys []URLSigKey) error { + rb.urlSigKeys.fromGeneric(keys) + return nil +} + +// Start initiates the connection to the backend DB. 
+func (rb *RiakBackend) Start() error { + tlsConfig := &tls.Config{ + InsecureSkipVerify: rb.cfg.Insecure, + MaxVersion: rb.cfg.TLSVersion, + } + auth := &riak.AuthOptions{ + User: rb.cfg.User, + Password: rb.cfg.Password, + TlsConfig: tlsConfig, + } + + cluster, err := getRiakCluster(rb.cfg, auth) + if err != nil { + return err + } + if err := cluster.Start(); err != nil { + return fmt.Errorf("unable to start riak cluster: %w", err) + } + + rb.cluster = cluster + rb.sslKeys = riakSSLKeyTable{} + rb.dnssecKeys = riakDNSSecKeyTable{} + rb.urlSigKeys = riakURLSigKeyTable{} + rb.uriSigningKeys = riakURISignKeyTable{} + return nil +} + +// Close terminates the connection to the backend DB. +func (rb *RiakBackend) Close() error { + if err := rb.cluster.Stop(); err != nil { + return err + } + return nil +} + +// Ping checks the connection to the backend DB. +func (rb *RiakBackend) Ping() error { + return ping(rb.cluster) +} + +// GetSSLKeys converts the backends internal key representation into the common representation (SSLKey). +func (rb *RiakBackend) GetSSLKeys() ([]SSLKey, error) { + return rb.sslKeys.toGeneric(), nil +} + +// GetDNSSecKeys converts the backends internal key representation into the common representation (DNSSecKey). +func (rb *RiakBackend) GetDNSSecKeys() ([]DNSSecKey, error) { + return rb.dnssecKeys.toGeneric(), nil +} + +// GetURISignKeys converts the pg internal key representation into the common representation (URISignKey). +func (rb *RiakBackend) GetURISignKeys() ([]URISignKey, error) { + return rb.uriSigningKeys.toGeneric(), nil +} + +// GetURLSigKeys converts the backends internal key representation into the common representation (URLSigKey). +func (rb *RiakBackend) GetURLSigKeys() ([]URLSigKey, error) { + return rb.urlSigKeys.toGeneric(), nil +} + +// Fetch gets all of the keys from the backend DB. 
+func (rb *RiakBackend) Fetch() error { + if err := rb.sslKeys.gatherKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + if err := rb.dnssecKeys.gatherKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + if err := rb.urlSigKeys.gatherKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + if err := rb.uriSigningKeys.gatherKeys(rb.cluster, rb.cfg.Timeout); err != nil { + return err + } + + return nil +} + +type riakSSLKeyRecord struct { + tc.DeliveryServiceSSLKeys + Version string +} +type riakSSLKeyTable struct { + Records []riakSSLKeyRecord +} + +func (tbl *riakSSLKeyTable) gatherKeys(cluster *riak.Cluster, timeout int) error { + searchDocs, err := search(cluster, INDEX_SSL, "cdn:*", "", 1000, SCHEMA_SSL_FIELDS[:]) + if err != nil { + return fmt.Errorf("RiakSSLKey gatherKeys: %w", err) + } + + tbl.Records = make([]riakSSLKeyRecord, len(searchDocs)) + for i, doc := range searchDocs { + objs, err := getObject(cluster, doc.Bucket, doc.Key, timeout) + if err != nil { + return err + } + if len(objs) < 1 { + return fmt.Errorf("RiakSSLKey gatherKeys unable to find any objects with key %s and bucket %s, but search results were returned", doc.Key, doc.Bucket) + } + if len(objs) > 1 { + return fmt.Errorf("RiakSSLKey gatherKeys key '%s' more than 1 ssl key record found %d\n", doc.Key, len(objs)) + } + var obj tc.DeliveryServiceSSLKeys + if err = json.Unmarshal(objs[0].Value, &obj); err != nil { + return fmt.Errorf("RiakSSLKey gatherKeys key '%s' unable to unmarshal object into tc.DeliveryServiceSSLKeys: %w", doc.Key, err) + } + tbl.Records[i] = riakSSLKeyRecord{ + DeliveryServiceSSLKeys: obj, + Version: strings.Split(objs[0].Key, "-")[1], + } + } + return nil +} +func (tbl *riakSSLKeyTable) toGeneric() []SSLKey { + keys := make([]SSLKey, len(tbl.Records)) + + for i, record := range tbl.Records { + keys[i] = SSLKey{ + DeliveryServiceSSLKeys: record.DeliveryServiceSSLKeys, + Version: record.Version, + } + } + + return keys +} +func (tbl 
*riakSSLKeyTable) fromGeneric(keys []SSLKey) { + tbl.Records = make([]riakSSLKeyRecord, len(keys)) + + for i, record := range keys { + tbl.Records[i] = riakSSLKeyRecord{ + DeliveryServiceSSLKeys: record.DeliveryServiceSSLKeys, + Version: record.Version, + } + } +} +func (tbl *riakSSLKeyTable) insertKeys(cluster *riak.Cluster, timeout int) error { + for _, record := range tbl.Records { + objBytes, err := json.Marshal(record.DeliveryServiceSSLKeys) + if err != nil { + return fmt.Errorf("RiakSSLKey insertKeys '%s' failed to marshal keys: %w", record.Key, err) + } + if err = setObject(cluster, makeRiakObject(objBytes, record.DeliveryService+"-"+record.Version), BUCKET_SSL, timeout); err != nil { + return fmt.Errorf("RiakSSLKey insertKeys '%s': %w", record.Key, err) + } + } + return nil +} +func (tbl *riakSSLKeyTable) validate() []string { + errs := []string{} + for _, record := range tbl.Records { + if record.DeliveryService == "" { + errs = append(errs, fmt.Sprintf("SSL Key '%s': Delivery Service is blank!", record.Key)) + } + if record.CDN == "" { + errs = append(errs, fmt.Sprintf("SSL Key '%s': CDN is blank!", record.Key)) + } + if record.Version == "" { + errs = append(errs, fmt.Sprintf("SSL Key '%s': Version is blank!", record.Key)) + } + } + return errs +} + +type riakDNSSecKeyRecord struct { + CDN string + Key tc.DNSSECKeysRiak Review comment: nit: `tc.DNSSECKeysRiak` is deprecated, `tc.DNSSECKeysTrafficVault` is the same thing and could be used instead ########## File path: traffic_ops/app/db/traffic_vault_migrate/postgres.go ########## @@ -0,0 +1,744 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/apache/trafficcontrol/lib/go-log" + "github.com/apache/trafficcontrol/lib/go-tc" + util "github.com/apache/trafficcontrol/lib/go-util" + + _ "github.com/lib/pq" +) + +// PGConfig represents the configuration options available to the PG backend. +type PGConfig struct { + Host string `json:"host"` + Port string `json:"port"` + User string `json:"user"` + Password string `json:"password"` + SSLMode string `json:"sslmode"` + Database string `json:"database"` + Key string `json:"aesKey"` + AESKey []byte +} + +// PGBackend is the Postgres implementation of TVBackend. +type PGBackend struct { + sslKey pgSSLKeyTable + dnssec pgDNSSecTable + uriSigningKeys pgURISignKeyTable + urlSigKeys pgURLSigKeyTable + cfg PGConfig + db *sql.DB +} + +// String returns a high level overview of the backend and its keys. +func (pg *PGBackend) String() string { + data := fmt.Sprintf("PG server %s@%s:%s\n", pg.cfg.User, pg.cfg.Host, pg.cfg.Port) + data += fmt.Sprintf("\tSSL Keys: %d\n", len(pg.sslKey.Records)) + data += fmt.Sprintf("\tDNSSec Keys: %d\n", len(pg.dnssec.Records)) + data += fmt.Sprintf("\tURI Signing Keys: %d\n", len(pg.uriSigningKeys.Records)) + data += fmt.Sprintf("\tURL Sig Keys: %d\n", len(pg.urlSigKeys.Records)) + return data +} + +// Name returns the name for this backend. +func (pg *PGBackend) Name() string { + return "PG" +} + +// ReadConfigFile takes in a filename and will read it into the backends config. 
+func (pg *PGBackend) ReadConfigFile(configFile string) error { + var err error + if err = UnmarshalConfig(configFile, &pg.cfg); err != nil { + return err + } + + if pg.cfg.AESKey, err = base64.StdEncoding.DecodeString(pg.cfg.Key); err != nil { + return fmt.Errorf("unable to decode PG AESKey '%s': %w", pg.cfg.Key, err) + } + + if err = util.ValidateAESKey(pg.cfg.AESKey); err != nil { + return fmt.Errorf("unable to validate PG AESKey '%s'", pg.cfg.Key) + } + return nil +} + +// Insert takes the current keys and inserts them into the backend DB. +func (pg *PGBackend) Insert() error { + if err := pg.sslKey.insertKeys(pg.db); err != nil { + return err + } + if err := pg.dnssec.insertKeys(pg.db); err != nil { + return err + } + if err := pg.urlSigKeys.insertKeys(pg.db); err != nil { + return err + } + if err := pg.uriSigningKeys.insertKeys(pg.db); err != nil { + return err + } + return nil +} + +// Start initiates the connection to the backend DB. +func (pg *PGBackend) Start() error { + sqlStr := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=%s", pg.cfg.User, pg.cfg.Password, pg.cfg.Host, pg.cfg.Port, pg.cfg.Database, pg.cfg.SSLMode) + db, err := sql.Open("postgres", sqlStr) + if err != nil { + sqlStr = strings.Replace(sqlStr, pg.cfg.Password, "*", 1) + return fmt.Errorf("unable to start PG client with connection string '%s': %w", sqlStr, err) + } + + pg.db = db + pg.sslKey = pgSSLKeyTable{} + pg.dnssec = pgDNSSecTable{} + pg.urlSigKeys = pgURLSigKeyTable{} + pg.uriSigningKeys = pgURISignKeyTable{} + + return nil +} + +// ValidateKey validates that the keys are valid (in most cases, certain fields are not null). +func (pg *PGBackend) ValidateKey() []string { + var allErrs []string + if errs := pg.sslKey.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := pg.dnssec.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + if errs := pg.uriSigningKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) 
+ } + if errs := pg.urlSigKeys.validate(); errs != nil { + allErrs = append(allErrs, errs...) + } + return allErrs +} + +// Close terminates the connection to the backend DB. +func (pg *PGBackend) Close() error { + return pg.db.Close() +} + +// Ping checks the connection to the backend DB. +func (pg *PGBackend) Ping() error { + return pg.db.Ping() +} + +// Fetch gets all of the keys from the backend DB. +func (pg *PGBackend) Fetch() error { + if err := pg.sslKey.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.dnssec.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.urlSigKeys.gatherKeys(pg.db); err != nil { + return err + } + + if err := pg.uriSigningKeys.gatherKeys(pg.db); err != nil { + return err + } + + return nil +} + +// GetSSLKeys converts the backends internal key representation into the common representation (SSLKey). +func (pg *PGBackend) GetSSLKeys() ([]SSLKey, error) { + if err := pg.sslKey.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.sslKey.toGeneric(), nil +} + +// SetSSLKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetSSLKeys(keys []SSLKey) error { + pg.sslKey.fromGeneric(keys) + return pg.sslKey.encrypt(pg.cfg.AESKey) +} + +// GetDNSSecKeys converts the backends internal key representation into the common representation (DNSSecKey). +func (pg *PGBackend) GetDNSSecKeys() ([]DNSSecKey, error) { + if err := pg.dnssec.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.dnssec.toGeneric(), nil +} + +// SetDNSSecKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetDNSSecKeys(keys []DNSSecKey) error { + pg.dnssec.fromGeneric(keys) + return pg.dnssec.encrypt(pg.cfg.AESKey) +} + +// GetURISignKeys converts the pg internal key representation into the common representation (URISignKey). 
+func (pg *PGBackend) GetURISignKeys() ([]URISignKey, error) { + if err := pg.uriSigningKeys.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.uriSigningKeys.toGeneric(), nil +} + +// SetURISignKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetURISignKeys(keys []URISignKey) error { + pg.uriSigningKeys.fromGeneric(keys) + return pg.uriSigningKeys.encrypt(pg.cfg.AESKey) +} + +// GetURLSigKeys converts the backends internal key representation into the common representation (URLSigKey). +func (pg *PGBackend) GetURLSigKeys() ([]URLSigKey, error) { + if err := pg.urlSigKeys.decrypt(pg.cfg.AESKey); err != nil { + return nil, err + } + return pg.urlSigKeys.toGeneric(), nil +} + +// SetURLSigKeys takes in keys and converts & encrypts the data into the backends internal format. +func (pg *PGBackend) SetURLSigKeys(keys []URLSigKey) error { + pg.urlSigKeys.fromGeneric(keys) + return pg.urlSigKeys.encrypt(pg.cfg.AESKey) +} + +type pgCommonRecord struct { + DataEncrypted []byte +} + +type pgDNSSecRecord struct { + Key tc.DNSSECKeysTrafficVault + CDN string + pgCommonRecord +} +type pgDNSSecTable struct { + Records []pgDNSSecRecord +} + +func (tbl *pgDNSSecTable) gatherKeys(db *sql.DB) error { + sz, err := getSize(db, "dnssec") + if err != nil { + log.Errorln("PGDNSSec gatherKeys: unable to determine size of dnssec table") + } + tbl.Records = make([]pgDNSSecRecord, sz) + + query := "SELECT cdn, data from dnssec" + rows, err := db.Query(query) + if err != nil { + return fmt.Errorf("PGDNSSec gatherKeys: unable to run query '%s': %w", query, err) + } + defer log.Close(rows, "closing dnssec query") + i := 0 + for rows.Next() { + if i > len(tbl.Records)-1 { + return fmt.Errorf("PGDNSSec gatherKeys got more results than expected %d", len(tbl.Records)) + } + if err := rows.Scan(&tbl.Records[i].CDN, &tbl.Records[i].DataEncrypted); err != nil { + return fmt.Errorf("PGDNSSec gatherKeys unable to scan row: 
%w", err) + } + i += 1 + } + return nil +} +func (tbl *pgDNSSecTable) decrypt(aesKey []byte) error { + for i, _ := range tbl.Records { + if err := decryptInto(aesKey, tbl.Records[i].DataEncrypted, &tbl.Records[i].Key); err != nil { + return fmt.Errorf("unable to decrypt into keys: %w", err) + } + } + return nil +} +func (tbl *pgDNSSecTable) encrypt(aesKey []byte) error { + for i, dns := range tbl.Records { + data, err := json.Marshal(&dns.Key) + if err != nil { + return fmt.Errorf("encrypt issue marshalling keys: %w", err) + } + dat, err := encrypt(data, aesKey) + if err != nil { + return fmt.Errorf("encrypt error: %w", err) + } + tbl.Records[i].DataEncrypted = dat + } + return nil +} +func (tbl *pgDNSSecTable) toGeneric() []DNSSecKey { + keys := make([]DNSSecKey, len(tbl.Records)) + + for i, record := range tbl.Records { + keys[i] = DNSSecKey{ + CDN: record.CDN, + DNSSECKeysTrafficVault: record.Key, + } + } + + return keys +} +func (tbl *pgDNSSecTable) fromGeneric(keys []DNSSecKey) { + tbl.Records = make([]pgDNSSecRecord, len(keys)) + + for i, key := range keys { + tbl.Records[i] = pgDNSSecRecord{ + Key: key.DNSSECKeysTrafficVault, + CDN: key.CDN, + pgCommonRecord: pgCommonRecord{ + DataEncrypted: nil, + }, + } + } +} +func (tbl *pgDNSSecTable) validate() []string { + for _, record := range tbl.Records { + if record.DataEncrypted == nil && len(record.Key) > 0 { + return []string{fmt.Sprintf("DNSSEC Key CDN '%s': DataEncrypted is blank!", record.CDN)} + } + } + return nil +} +func (tbl *pgDNSSecTable) insertKeys(db *sql.DB) error { + queryBase := "INSERT INTO dnssec (cdn, data) VALUES %s ON CONFLICT (cdn) DO UPDATE SET data = EXCLUDED.data" + stride := 2 + queryArgs := make([]interface{}, len(tbl.Records)*stride) + for i, record := range tbl.Records { + j := i * stride + queryArgs[j] = record.CDN + queryArgs[j+1] = record.DataEncrypted + } + return insertIntoTable(db, queryBase, stride, queryArgs) +} + +type pgSSLKeyRecord struct { + Keys tc.DeliveryServiceSSLKeys 
+ pgCommonRecord + + // These records are stored on the table but are duplicated + DeliveryService string + CDN string + Version string +} +type pgSSLKeyTable struct { + Records []pgSSLKeyRecord +} + +func (tbl *pgSSLKeyTable) insertKeys(db *sql.DB) error { + queryBase := "INSERT INTO sslkey (deliveryservice, data, cdn, version) VALUES %s ON CONFLICT (deliveryservice,cdn,version) DO UPDATE SET data = EXCLUDED.data" + stride := 4 + queryArgs := make([]interface{}, len(tbl.Records)*stride) + for i, record := range tbl.Records { + j := i * stride + + queryArgs[j] = record.DeliveryService + queryArgs[j+1] = record.DataEncrypted + queryArgs[j+2] = record.CDN + queryArgs[j+3] = record.Version + } + return insertIntoTable(db, queryBase, 4, queryArgs) +} +func (tbl *pgSSLKeyTable) gatherKeys(db *sql.DB) error { + sz, err := getSize(db, "sslkey") + if err != nil { + return fmt.Errorf("PGSSLKey gatherKeys unable to determine size of sslkey table: %w", err) + } + tbl.Records = make([]pgSSLKeyRecord, sz) + + query := "SELECT data, deliveryservice, cdn, version from sslkey" + rows, err := db.Query(query) + if err != nil { + return fmt.Errorf("PGSSLKey gatherKeys unable to run query '%s': %w", query, err) + } + defer log.Close(rows, "closing sslkey query") + i := 0 + for rows.Next() { + if i > len(tbl.Records)-1 { + return fmt.Errorf("PGSSLKey gatherKeys: got more results than expected") + } + if err := rows.Scan(&tbl.Records[i].DataEncrypted, &tbl.Records[i].DeliveryService, &tbl.Records[i].CDN, &tbl.Records[i].Version); err != nil { + return fmt.Errorf("PGSSLKey gatherKeys unable to scan %d row: %w", i, err) + } + i += 1 + } + return nil +} +func (tbl *pgSSLKeyTable) decrypt(aesKey []byte) error { + for i, dns := range tbl.Records { + if err := decryptInto(aesKey, dns.DataEncrypted, &tbl.Records[i].Keys); err != nil { + return fmt.Errorf("unable to decrypt into keys: %w", err) + } + } + return nil +} +func (tbl *pgSSLKeyTable) encrypt(aesKey []byte) error { + for i, dns := 
range tbl.Records { Review comment: nit: the loop variable here is named `dns`, but these are SSL key records; it looks like a copy-paste remnant from the DNSSEC table code and should be renamed (e.g. `record`). ########## File path: traffic_ops/app/db/traffic_vault_migrate/traffic_vault_migrate.go ########## @@ -0,0 +1,611 @@ +package main + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import ( + "encoding/json" + "errors" + "fmt" + "io/ioutil" + stdlog "log" + "os" + "reflect" + "sort" + "strings" + + "github.com/pborman/getopt/v2" + + "github.com/apache/trafficcontrol/lib/go-log" + "github.com/apache/trafficcontrol/lib/go-tc" +) + +var ( + fromType string + toType string + fromCfgPath string + toCfgPath string + logCfgPath string + keyFile string + dry bool + compare bool + noConfirm bool + dump bool + logLevel string + + cfg config = config{ + LogLocationError: log.LogLocationStderr, + LogLocationWarning: log.LogLocationStdout, + LogLocationInfo: log.LogLocationStdout, + LogLocationDebug: log.LogLocationNull, + LogLocationEvent: log.LogLocationNull, + } + riakBE RiakBackend = RiakBackend{} + pgBE PGBackend = PGBackend{} +) + +func init() { + fromTypePtr := getopt.StringLong("fromType", 't', riakBE.Name(), fmt.Sprintf("From server types (%v)", strings.Join(supportedTypes(), "|"))) + if fromTypePtr == nil { + stdlog.Fatal("unable to load fromType") + } + fromType = *fromTypePtr + + toTypePtr := getopt.StringLong("toType", 'o', pgBE.Name(), fmt.Sprintf("From server types (%v)", strings.Join(supportedTypes(), "|"))) + if toTypePtr == nil { + stdlog.Fatal("unable to load toType") + } + toType = *toTypePtr + + toCfgPtr := getopt.StringLong("toCfgPath", 'g', "pg.json", "To server config file") + if toCfgPtr == nil { + stdlog.Fatal("unable to load toCfg") + } + toCfgPath = *toCfgPtr + + fromCfgPtr := getopt.StringLong("fromCfgPath", 'f', "riak.json", "From server config file") + if fromCfgPtr == nil { + stdlog.Fatal("unable to load fromCfg") + } + fromCfgPath = *fromCfgPtr + + getopt.FlagLong(&dry, "dry", 'r', "Do not perform writes"). + SetOptional(). + SetFlag(). + SetGroup("no_insert") + + getopt.FlagLong(&compare, "compare", 'c', "Compare to and from server records"). + SetOptional(). + SetFlag(). + SetGroup("no_insert") + + getopt.FlagLong(&noConfirm, "noConfirm", 'm', "Requires confirmation before inserting records"). 
+ SetFlag() + + getopt.FlagLong(&dump, "dump", 'd', "Write keys (from 'from' server) to disk"). + SetOptional(). + SetGroup("disk_bck"). + SetFlag() + + getopt.FlagLong(&keyFile, "fill", 'i', "Insert data into `to` server with data this directory"). + SetOptional(). + SetGroup("disk_bck") + + getopt.FlagLong(&logCfgPath, "logCfg", 'l', "Log configuration file"). + SetOptional(). + SetGroup("log") + + getopt.FlagLong(&logLevel, "logLevel", 'e', "Print everything at above specified log level (error|warning|info|debug|event)"). + SetOptional(). + SetGroup("log") +} + +// supportBackends returns the backends available in this tool. +func supportedBackends() []TVBackend { + return []TVBackend{ + &riakBE, &pgBE, + } +} + +func main() { + getopt.ParseV2() + + initConfig() + + var fromSrv TVBackend + var toSrv TVBackend + + importData := keyFile != "" + toSrvUsed := !dump && !dry || keyFile != "" + + if !importData { + log.Infof("Initiating fromSrv %s...\n", fromType) + if !validateType(fromType) { + log.Errorln("Unknown fromType " + fromType) + os.Exit(1) + } + fromSrv = getBackendFromType(fromType) + if err := fromSrv.ReadConfigFile(fromCfgPath); err != nil { + log.Errorf("Unable to read fromSrv cfg: %v", err) + os.Exit(1) + } + + if err := fromSrv.Start(); err != nil { + log.Errorf("issue starting fromSrv: %v", err) + os.Exit(1) + } + defer log.Close(fromSrv, "closing fromSrv") + + if err := fromSrv.Ping(); err != nil { + log.Errorf("Unable to ping fromSrv: %v", err) + os.Exit(1) + } + } + + if toSrvUsed { + log.Infof("Initiating toSrv %s...\n", toType) + if !validateType(toType) { + log.Errorln("Unknown toType " + toType) + os.Exit(1) + } + toSrv = getBackendFromType(toType) + + if err := toSrv.ReadConfigFile(toCfgPath); err != nil { + log.Errorf("Unable to read toSrv cfg: %v", err) + os.Exit(1) + } + + if err := toSrv.Start(); err != nil { + log.Errorf("issue starting toSrv: %v", err) + os.Exit(1) + } + defer log.Close(toSrv, "closing toSrv") + + if err := 
toSrv.Ping(); err != nil { + log.Errorf("Unable to ping toSrv: %v", err) + os.Exit(1) + } + } + + var fromSecret Secrets + if !importData { + var err error + log.Infof("Fetching data from %s...\n", fromSrv.Name()) + if err = fromSrv.Fetch(); err != nil { + log.Errorf("Unable to fetch fromSrv data: %v", err) + os.Exit(1) + } + + if fromSecret, err = GetKeys(fromSrv); err != nil { + log.Errorln(err) + os.Exit(1) + } + + if err := Validate(fromSrv); err != nil { + log.Errorln(err) + os.Exit(1) + } + + } else { + err := fromSecret.fill(keyFile) + if err != nil { + log.Errorln(err) + os.Exit(1) + } + } + + if dump { + log.Infof("Dumping data from %s...\n", fromSrv.Name()) + fromSecret.dump("dump") + return + } + + if compare { + log.Infof("Fetching data from %s...\n", toSrv.Name()) + if err := toSrv.Fetch(); err != nil { + log.Errorf("Unable to fetch toSrv data: %v\n", err) + os.Exit(1) + } + + toSecret, err := GetKeys(toSrv) + if err != nil { + log.Errorln(err) + os.Exit(1) + } + log.Infoln("Validating " + toSrv.Name()) + if err := toSrv.ValidateKey(); err != nil && len(err) > 0 { + log.Errorln(strings.Join(err, "\n")) + os.Exit(1) + } + + fromSecret.sort() + toSecret.sort() + + if !importData { + log.Infoln(fromSrv.String()) + } else { + log.Infof("Disk backup:\n\tSSL Keys: %d\n\tDNSSec Keys: %d\n\tURI Keys: %d\n\tURL Keys: %d\n", len(fromSecret.sslkeys), len(fromSecret.dnssecKeys), len(fromSecret.uriKeys), len(fromSecret.urlKeys)) + } + log.Infoln(toSrv.String()) + + if !reflect.DeepEqual(fromSecret.sslkeys, toSecret.sslkeys) { + log.Errorln("from sslkeys and to sslkeys don't match") + os.Exit(1) + } + if !reflect.DeepEqual(fromSecret.dnssecKeys, toSecret.dnssecKeys) { + log.Errorln("from dnssec and to dnssec don't match") + os.Exit(1) + } + if !reflect.DeepEqual(fromSecret.uriKeys, toSecret.uriKeys) { + log.Errorln("from uri and to uri don't match") + os.Exit(1) + } + if !reflect.DeepEqual(fromSecret.urlKeys, toSecret.urlKeys) { + log.Errorln("from url and to url 
don't match") + os.Exit(1) + } + log.Infoln("Both data sources have the same keys") + return + } + + if toSrvUsed { + log.Infof("Setting %s keys...\n", toSrv.Name()) + if err := SetKeys(toSrv, fromSecret); err != nil { + log.Errorln(err) + os.Exit(1) + } + + if err := Validate(toSrv); err != nil { + log.Errorln(err) + os.Exit(1) + } + } + + if !importData { + log.Infoln(fromSrv.String()) + } else { + log.Infof("Disk backup:\n\tSSL Keys: %d\n\tDNSSec Keys: %d\n\tURI Keys: %d\n\tURL Keys: %d\n", len(fromSecret.sslkeys), len(fromSecret.dnssecKeys), len(fromSecret.uriKeys), len(fromSecret.urlKeys)) + } + + if dry { + return + } + + if !noConfirm { + ans := "q" + for { + fmt.Print("Confirm data insertion (y/n): ") + if _, err := fmt.Scanln(&ans); err != nil { + log.Errorln("unable to get user input") + os.Exit(1) + } + + if ans == "y" { + break + } else if ans == "n" { + return + } + } + } + log.Infof("Inserting data into %s...\n", toSrv.Name()) + if err := toSrv.Insert(); err != nil { + log.Errorln(err) + os.Exit(1) + } +} + +// Validate runs the ValidateKey method on the backend. +func Validate(be TVBackend) error { + if errs := be.ValidateKey(); errs != nil && len(errs) > 0 { + return errors.New(fmt.Sprintf("Validation Errors (%s): \n%s", be.Name(), strings.Join(errs, "\n"))) + } + return nil +} + +// SetKeys will set all of the keys for a backend. +func SetKeys(be TVBackend, s Secrets) error { + if err := be.SetSSLKeys(s.sslkeys); err != nil { + return fmt.Errorf("Unable to set %s ssl keys: %w", be.Name(), err) + } + if err := be.SetDNSSecKeys(s.dnssecKeys); err != nil { + return fmt.Errorf("Unable to set %s dnssec keys: %w", be.Name(), err) + } + if err := be.SetURLSigKeys(s.urlKeys); err != nil { + return fmt.Errorf("Unable to set %v url keys: %v", be.Name(), err) + } + if err := be.SetURISignKeys(s.uriKeys); err != nil { + return fmt.Errorf("Unable to set %v uri keys: %v", be.Name(), err) + } + return nil +} + +// GetKeys will get all of the keys for a backend. 
+func GetKeys(be TVBackend) (Secrets, error) { + var secret Secrets + var err error + if secret.sslkeys, err = be.GetSSLKeys(); err != nil { + return Secrets{}, fmt.Errorf("Unable to get %v sslkeys: %v", be.Name(), err) + } + if secret.dnssecKeys, err = be.GetDNSSecKeys(); err != nil { + return Secrets{}, fmt.Errorf("Unable to get %v dnssec keys: %v", be.Name(), err) + } + if secret.uriKeys, err = be.GetURISignKeys(); err != nil { + return Secrets{}, fmt.Errorf("Unable to get %v uri keys: %v", be.Name(), err) + } + if secret.urlKeys, err = be.GetURLSigKeys(); err != nil { + return Secrets{}, fmt.Errorf("Unable to %v url keys: %v", be.Name(), err) + } + return secret, nil +} + +// UnmarshalConfig takes in a config file and a type and will read the config file into the reflected type. +func UnmarshalConfig(configFile string, config interface{}) error { + data, err := ioutil.ReadFile(configFile) + if err != nil { + return err + } + err = json.Unmarshal(data, config) + if err != nil { + return err + } + + return nil +} + +// TVBackend represents a TV backend that can be have data migrated to/from +type TVBackend interface { Review comment: I like that this is interface-based 👍 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected]
