This is an automated email from the ASF dual-hosted git repository.
HTHou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iotdb-client-go.git
The following commit(s) were added to refs/heads/main by this push:
new 1a1ff84 Add TLS and mTLS support (#165)
1a1ff84 is described below
commit 1a1ff84289724216bd474cdd455c1e13ea68c2d6
Author: Haonan <[email protected]>
AuthorDate: Sat Jun 27 19:26:37 2026 +0800
Add TLS and mTLS support (#165)
---
README.md | 23 +++++++
README_ZH.md | 23 +++++++
client/session.go | 58 +++++++++---------
client/sessionpool.go | 3 +
client/tls.go | 120 ++++++++++++++++++++++++++++++++++++
client/tls_test.go | 167 ++++++++++++++++++++++++++++++++++++++++++++++++++
6 files changed, 365 insertions(+), 29 deletions(-)
diff --git a/README.md b/README.md
index a963a8e..2045b53 100644
--- a/README.md
+++ b/README.md
@@ -79,6 +79,29 @@ curl -o session_example.go -L
https://github.com/apache/iotdb-client-go/raw/main
go run session_example.go
```
+## TLS/mTLS
+
+Set `TLSConfig` on `client.Config`, `client.ClusterConfig`, or
`client.PoolConfig` to enable TLS. Set `CAFile` to verify a server certificate
issued by a private CA, and set both `CertFile` and `KeyFile` when the server
requires mTLS client authentication.
+
+```golang
+config := &client.Config{
+ Host: host,
+ Port: port,
+ UserName: user,
+ Password: password,
+ TLSConfig: &client.TLSConfig{
+ CAFile: "/path/to/ca.pem",
+ CertFile: "/path/to/client.pem",
+ KeyFile: "/path/to/client-key.pem",
+ },
+}
+session := client.NewSession(config)
+if err := session.Open(false, 0); err != nil {
+ log.Fatal(err)
+}
+defer session.Close()
+```
+
## How to Use the SessionPool
SessionPool wraps a set of Sessions. With SessionPool, users do not
need to manage how session connections are reused.
diff --git a/README_ZH.md b/README_ZH.md
index e856172..e680123 100644
--- a/README_ZH.md
+++ b/README_ZH.md
@@ -76,6 +76,29 @@ curl -o session_example.go -L
https://github.com/apache/iotdb-client-go/raw/main
go run session_example.go
```
+## TLS/mTLS
+
+在 `client.Config`、`client.ClusterConfig` 或 `client.PoolConfig` 上设置 `TLSConfig`
即可启用 TLS。如果服务端要求 mTLS 客户端认证,同时设置 `CertFile` 和 `KeyFile`。
+
+```golang
+config := &client.Config{
+ Host: host,
+ Port: port,
+ UserName: user,
+ Password: password,
+ TLSConfig: &client.TLSConfig{
+ CAFile: "/path/to/ca.pem",
+ CertFile: "/path/to/client.pem",
+ KeyFile: "/path/to/client-key.pem",
+ },
+}
+session := client.NewSession(config)
+if err := session.Open(false, 0); err != nil {
+ log.Fatal(err)
+}
+defer session.Close()
+```
+
## SessionPool
通过SessionPool管理session,用户不需要考虑如何重用session,当到达pool的最大值时,获取session的请求会阻塞
注意:session使用完成后需要调用PutBack方法
diff --git a/client/session.go b/client/session.go
index 28b326e..1a4cce4 100644
--- a/client/session.go
+++ b/client/session.go
@@ -26,7 +26,6 @@ import (
"errors"
"fmt"
"log"
- "net"
"reflect"
"sort"
"strings"
@@ -68,6 +67,7 @@ type Config struct {
sqlDialect string
Version Version
Database string
+ TLSConfig *TLSConfig
}
type Session struct {
@@ -100,13 +100,10 @@ func (s *Session) Open(enableRPCCompression bool,
connectionTimeoutInMs int) err
var err error
- // In thrift 0.14.1, this func returns two values; in newer versions,
it returns one.
- s.trans = thrift.NewTSocketConf(net.JoinHostPort(s.config.Host,
s.config.Port), &thrift.TConfiguration{
- ConnectTimeout: time.Duration(connectionTimeoutInMs) *
time.Millisecond, // Use 0 for no timeout
- })
- // s.trans = thrift.NewTFramedTransport(s.trans) // deprecated
- tmp_conf := thrift.TConfiguration{MaxFrameSize:
thrift.DEFAULT_MAX_FRAME_SIZE}
- s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
+ s.trans, err = newTransport(s.config.Host, s.config.Port,
connectionTimeoutInMs, s.config.TLSConfig)
+ if err != nil {
+ return err
+ }
if !s.trans.IsOpen() {
err = s.trans.Open()
if err != nil {
@@ -154,6 +151,7 @@ type ClusterConfig struct {
ConnectRetryMax int
sqlDialect string
Database string
+ TLSConfig *TLSConfig
}
func (s *Session) OpenCluster(enableRPCCompression bool) error {
@@ -1326,26 +1324,31 @@ func newClusterSessionWithSqlDialect(clusterConfig
*ClusterConfig) (Session, err
session.endPointList[i] = node
}
var err error
+ var lastErr error
for i := range session.endPointList {
ep := session.endPointList[i]
- session.trans = thrift.NewTSocketConf(net.JoinHostPort(ep.Host,
ep.Port), &thrift.TConfiguration{
- ConnectTimeout: time.Duration(0), // Use 0 for no
timeout
- })
- // session.trans = thrift.NewTFramedTransport(session.trans)
// deprecated
- tmp_conf := thrift.TConfiguration{MaxFrameSize:
thrift.DEFAULT_MAX_FRAME_SIZE}
- session.trans = thrift.NewTFramedTransportConf(session.trans,
&tmp_conf)
+ session.trans, err = newTransport(ep.Host, ep.Port, 0,
clusterConfig.TLSConfig)
+ if err != nil {
+ lastErr = err
+ log.Println(err)
+ continue
+ }
if !session.trans.IsOpen() {
err = session.trans.Open()
if err != nil {
+ lastErr = err
log.Println(err)
} else {
session.config = getConfig(ep.Host, ep.Port,
- clusterConfig.UserName,
clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone,
clusterConfig.ConnectRetryMax, clusterConfig.Database, clusterConfig.sqlDialect)
+ clusterConfig.UserName,
clusterConfig.Password, clusterConfig.FetchSize, clusterConfig.TimeZone,
clusterConfig.ConnectRetryMax, clusterConfig.Database,
clusterConfig.sqlDialect, clusterConfig.TLSConfig)
break
}
}
}
- if !session.trans.IsOpen() {
+ if session.trans == nil || !session.trans.IsOpen() {
+ if lastErr != nil {
+ return session, fmt.Errorf("no server can connect: %w",
lastErr)
+ }
return session, fmt.Errorf("no server can connect")
}
return session, nil
@@ -1354,18 +1357,14 @@ func newClusterSessionWithSqlDialect(clusterConfig
*ClusterConfig) (Session, err
func (s *Session) initClusterConn(node endPoint) error {
var err error
- s.trans = thrift.NewTSocketConf(net.JoinHostPort(node.Host, node.Port),
&thrift.TConfiguration{
- ConnectTimeout: time.Duration(0), // Use 0 for no timeout
- })
- if err == nil {
- // s.trans = thrift.NewTFramedTransport(s.trans) //
deprecated
- tmp_conf := thrift.TConfiguration{MaxFrameSize:
thrift.DEFAULT_MAX_FRAME_SIZE}
- s.trans = thrift.NewTFramedTransportConf(s.trans, &tmp_conf)
- if !s.trans.IsOpen() {
- err = s.trans.Open()
- if err != nil {
- return err
- }
+ s.trans, err = newTransport(node.Host, node.Port, 0, s.config.TLSConfig)
+ if err != nil {
+ return err
+ }
+ if !s.trans.IsOpen() {
+ err = s.trans.Open()
+ if err != nil {
+ return err
}
}
@@ -1398,7 +1397,7 @@ func (s *Session) initClusterConn(node endPoint) error {
return err
}
-func getConfig(host string, port string, userName string, passWord string,
fetchSize int32, timeZone string, connectRetryMax int, database string,
sqlDialect string) *Config {
+func getConfig(host string, port string, userName string, passWord string,
fetchSize int32, timeZone string, connectRetryMax int, database string,
sqlDialect string, tlsConfig *TLSConfig) *Config {
return &Config{
Host: host,
Port: port,
@@ -1409,6 +1408,7 @@ func getConfig(host string, port string, userName string,
passWord string, fetch
ConnectRetryMax: connectRetryMax,
sqlDialect: sqlDialect,
Database: database,
+ TLSConfig: tlsConfig,
}
}
diff --git a/client/sessionpool.go b/client/sessionpool.go
index 757b298..c481bf4 100644
--- a/client/sessionpool.go
+++ b/client/sessionpool.go
@@ -50,6 +50,7 @@ type PoolConfig struct {
TimeZone string
ConnectRetryMax int
Database string
+ TLSConfig *TLSConfig
sqlDialect string
}
@@ -146,6 +147,7 @@ func getSessionConfig(config *PoolConfig) *Config {
ConnectRetryMax: config.ConnectRetryMax,
sqlDialect: config.sqlDialect,
Database: config.Database,
+ TLSConfig: config.TLSConfig,
}
}
@@ -159,6 +161,7 @@ func getClusterSessionConfig(config *PoolConfig)
*ClusterConfig {
ConnectRetryMax: config.ConnectRetryMax,
sqlDialect: config.sqlDialect,
Database: config.Database,
+ TLSConfig: config.TLSConfig,
}
}
diff --git a/client/tls.go b/client/tls.go
new file mode 100644
index 0000000..4e09b30
--- /dev/null
+++ b/client/tls.go
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package client
+
+import (
+ "crypto/tls"
+ "crypto/x509"
+ "fmt"
+ "net"
+ "os"
+ "time"
+
+ "github.com/apache/thrift/lib/go/thrift"
+)
+
// TLSConfig describes how an IoTDB client secures its connection with TLS.
// With CertFile and KeyFile left empty only the server is verified; supplying
// both additionally presents a client certificate (mTLS).
type TLSConfig struct {
	// Config optionally seeds the TLS settings. It is cloned before use.
	Config *tls.Config

	// CAFile optionally names a PEM file with CA certificate(s) used to
	// verify the server.
	CAFile string

	// CertFile and KeyFile optionally name the PEM client certificate and
	// private key presented for mTLS. Set both or neither.
	CertFile, KeyFile string
}
+
+func newTransport(host string, port string, connectionTimeoutInMs int,
tlsConfig *TLSConfig) (thrift.TTransport, error) {
+ conf := &thrift.TConfiguration{
+ ConnectTimeout: time.Duration(connectionTimeoutInMs) *
time.Millisecond,
+ MaxFrameSize: thrift.DEFAULT_MAX_FRAME_SIZE,
+ }
+ hostPort := net.JoinHostPort(host, port)
+
+ var base thrift.TTransport
+ if tlsConfig == nil {
+ base = thrift.NewTSocketConf(hostPort, conf)
+ } else {
+ cfg, err := buildTLSConfig(tlsConfig)
+ if err != nil {
+ return nil, err
+ }
+ conf.TLSConfig = cfg
+ base = thrift.NewTSSLSocketConf(hostPort, conf)
+ }
+
+ return thrift.NewTFramedTransportConf(base, conf), nil
+}
+
+func buildTLSConfig(config *TLSConfig) (*tls.Config, error) {
+ if config == nil {
+ return nil, nil
+ }
+
+ tlsConfig := &tls.Config{}
+ if config.Config != nil {
+ tlsConfig = config.Config.Clone()
+ }
+ if config.CAFile != "" {
+ rootCAs, err := loadCertPool(tlsConfig.RootCAs, config.CAFile)
+ if err != nil {
+ return nil, err
+ }
+ tlsConfig.RootCAs = rootCAs
+ }
+ if config.CertFile != "" || config.KeyFile != "" {
+ if config.CertFile == "" || config.KeyFile == "" {
+ return nil, fmt.Errorf("both TLS CertFile and KeyFile
must be set")
+ }
+ certificate, err := tls.LoadX509KeyPair(config.CertFile,
config.KeyFile)
+ if err != nil {
+ return nil, fmt.Errorf("load TLS client
certificate/key: %w", err)
+ }
+ tlsConfig.Certificates = append(tlsConfig.Certificates,
certificate)
+ }
+
+ return tlsConfig, nil
+}
+
+func loadCertPool(base *x509.CertPool, caFile string) (*x509.CertPool, error) {
+ rootCAs := base
+ if rootCAs != nil {
+ rootCAs = rootCAs.Clone()
+ } else {
+ systemPool, err := x509.SystemCertPool()
+ if err == nil && systemPool != nil {
+ rootCAs = systemPool
+ } else {
+ rootCAs = x509.NewCertPool()
+ }
+ }
+
+ caCert, err := os.ReadFile(caFile)
+ if err != nil {
+ return nil, fmt.Errorf("read TLS CA file %q: %w", caFile, err)
+ }
+ if !rootCAs.AppendCertsFromPEM(caCert) {
+ return nil, fmt.Errorf("append TLS CA file %q: no certificates
found", caFile)
+ }
+ return rootCAs, nil
+}
diff --git a/client/tls_test.go b/client/tls_test.go
new file mode 100644
index 0000000..c00ff9f
--- /dev/null
+++ b/client/tls_test.go
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package client
+
+import (
+ "crypto/ecdsa"
+ "crypto/elliptic"
+ "crypto/rand"
+ "crypto/tls"
+ "crypto/x509"
+ "crypto/x509/pkix"
+ "encoding/pem"
+ "math/big"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+)
+
+func TestBuildTLSConfigClonesBaseConfig(t *testing.T) {
+ base := &tls.Config{
+ MinVersion: tls.VersionTLS12,
+ }
+
+ cfg, err := buildTLSConfig(&TLSConfig{
+ Config: base,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if cfg == base {
+ t.Fatal("buildTLSConfig must clone the base tls.Config")
+ }
+ if cfg.MinVersion != tls.VersionTLS12 {
+ t.Fatalf("MinVersion = %d, want %d", cfg.MinVersion,
tls.VersionTLS12)
+ }
+}
+
+func TestBuildTLSConfigLoadsFiles(t *testing.T) {
+ caFile, certFile, keyFile := writeTLSFiles(t)
+
+ cfg, err := buildTLSConfig(&TLSConfig{
+ CAFile: caFile,
+ CertFile: certFile,
+ KeyFile: keyFile,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if cfg.RootCAs == nil {
+ t.Fatal("RootCAs should be set")
+ }
+ if len(cfg.Certificates) != 1 {
+ t.Fatalf("Certificates length = %d, want 1",
len(cfg.Certificates))
+ }
+}
+
+func TestBuildTLSConfigRequiresCertAndKey(t *testing.T) {
+ _, err := buildTLSConfig(&TLSConfig{CertFile: "client.crt"})
+ if err == nil {
+ t.Fatal("expected error when CertFile is set without KeyFile")
+ }
+}
+
+func TestNewClusterSessionReturnsTLSConfigError(t *testing.T) {
+ missingCAFile := filepath.Join(t.TempDir(), "missing-ca.pem")
+
+ _, err := newClusterSessionWithSqlDialect(&ClusterConfig{
+ NodeUrls: []string{"127.0.0.1:6667"},
+ TLSConfig: &TLSConfig{CAFile: missingCAFile},
+ })
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "no server can connect") {
+ t.Fatalf("error = %q, want no server can connect", err)
+ }
+ if !strings.Contains(err.Error(), "read TLS CA file") {
+ t.Fatalf("error = %q, want TLS CA file detail", err)
+ }
+}
+
// writeTLSFiles creates a throwaway PKI in a per-test temporary directory: a
// self-signed ECDSA P-256 CA and a client certificate (ExtKeyUsageClientAuth)
// signed by it, both valid from one hour ago until one hour from now. It
// returns the paths of the PEM-encoded CA certificate, client certificate and
// client private key ("EC PRIVATE KEY"). Any failure aborts the test.
func writeTLSFiles(t *testing.T) (caFile string, certFile string, keyFile string) {
	t.Helper()

	dir := t.TempDir()
	now := time.Now()
	caKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	caTemplate := &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		Subject:               pkix.Name{CommonName: "iotdb-client-go-test-ca"},
		NotBefore:             now.Add(-time.Hour),
		NotAfter:              now.Add(time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	// Self-signed: the template is both subject and issuer.
	caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
	if err != nil {
		t.Fatal(err)
	}

	clientKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	clientTemplate := &x509.Certificate{
		SerialNumber: big.NewInt(2),
		Subject:      pkix.Name{CommonName: "iotdb-client-go-test-client"},
		NotBefore:    now.Add(-time.Hour),
		NotAfter:     now.Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}
	// Issued by the CA above and signed with the CA's key.
	clientDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey)
	if err != nil {
		t.Fatal(err)
	}
	clientKeyDER, err := x509.MarshalECPrivateKey(clientKey)
	if err != nil {
		t.Fatal(err)
	}

	caFile = filepath.Join(dir, "ca.pem")
	certFile = filepath.Join(dir, "client.pem")
	keyFile = filepath.Join(dir, "client-key.pem")
	writePEMFile(t, caFile, "CERTIFICATE", caDER)
	writePEMFile(t, certFile, "CERTIFICATE", clientDER)
	writePEMFile(t, keyFile, "EC PRIVATE KEY", clientKeyDER)
	return caFile, certFile, keyFile
}

// writePEMFile PEM-encodes der as a single block of blockType and writes it to
// filename with owner-only permissions, since the payload may be a private key.
// os.WriteFile also reports a failing Close, which a deferred Close would drop.
func writePEMFile(t *testing.T, filename string, blockType string, der []byte) {
	t.Helper()

	data := pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der})
	if data == nil {
		t.Fatalf("encode PEM block %q", blockType)
	}
	if err := os.WriteFile(filename, data, 0o600); err != nil {
		t.Fatal(err)
	}
}