taldir

Directory service to resolve wallet mailboxes by messenger addresses
Log | Files | Refs | Submodules | README | LICENSE

ssl.go (8973B)


      1 package pq
      2 
      3 import (
      4 	"bytes"
      5 	"crypto/tls"
      6 	"crypto/x509"
      7 	"encoding/pem"
      8 	"errors"
      9 	"fmt"
     10 	"net"
     11 	"os"
     12 	"path/filepath"
     13 	"slices"
     14 	"strings"
     15 	"sync"
     16 
     17 	"github.com/lib/pq/internal/pqutil"
     18 )
     19 
     20 // Registry for custom tls.Configs
     21 var (
     22 	tlsConfs   = make(map[string]*tls.Config)
     23 	tlsConfsMu sync.RWMutex
     24 )
     25 
     26 // RegisterTLSConfig registers a custom [tls.Config]. They are used by using
     27 // sslmode=pqgo-«key» in the connection string.
     28 //
     29 // Set the config to nil to remove a configuration.
     30 func RegisterTLSConfig(key string, config *tls.Config) error {
     31 	key = strings.TrimPrefix(key, "pqgo-")
     32 	if config == nil {
     33 		tlsConfsMu.Lock()
     34 		delete(tlsConfs, key)
     35 		tlsConfsMu.Unlock()
     36 		return nil
     37 	}
     38 
     39 	tlsConfsMu.Lock()
     40 	tlsConfs[key] = config
     41 	tlsConfsMu.Unlock()
     42 	return nil
     43 }
     44 
     45 func hasTLSConfig(key string) bool {
     46 	tlsConfsMu.RLock()
     47 	defer tlsConfsMu.RUnlock()
     48 	_, ok := tlsConfs[key]
     49 	return ok
     50 }
     51 
     52 func getTLSConfigClone(key string) *tls.Config {
     53 	tlsConfsMu.RLock()
     54 	defer tlsConfsMu.RUnlock()
     55 	if v, ok := tlsConfs[key]; ok {
     56 		return v.Clone()
     57 	}
     58 	return nil
     59 }
     60 
     61 // ssl generates a function to upgrade a net.Conn based on the "sslmode" and
     62 // related settings. The function is nil when no upgrade should take place.
     63 //
     64 // Don't refer to Config.SSLMode here, as the mode in arguments may be different
     65 // in case of sslmode=allow or prefer.
     66 func ssl(cfg Config, mode SSLMode) (func(net.Conn) (net.Conn, error), error) {
     67 	var (
     68 		home = pqutil.Home(true)
     69 		// Don't set defaults here, because tlsConf may be overwritten if a
     70 		// custom one was registered. Set it after the sslmode switch.
     71 		tlsConf = &tls.Config{}
     72 		// Only verify the CA signing but not the hostname.
     73 		verifyCaOnly = false
     74 	)
     75 	if mode.useSSL() && !cfg.SSLInline && cfg.SSLRootCert == "" && home != "" {
     76 		f := filepath.Join(home, "root.crt")
     77 		if _, err := os.Stat(f); err == nil {
     78 			cfg.SSLRootCert = f
     79 		}
     80 	}
     81 	switch {
     82 	case mode == SSLModeDisable || mode == SSLModeAllow:
     83 		return nil, nil
     84 
     85 	case mode == "" || mode == SSLModeRequire || mode == SSLModePrefer:
     86 		// Skip TLS's own verification since it requires full verification.
     87 		tlsConf.InsecureSkipVerify = true
     88 
     89 		// From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
     90 		//
     91 		// For backwards compatibility with earlier versions of PostgreSQL, if a
     92 		// root CA file exists, the behavior of sslmode=require will be the same
     93 		// as that of verify-ca, meaning the server certificate is validated
     94 		// against the CA. Relying on this behavior is discouraged, and
     95 		// applications that need certificate validation should always use
     96 		// verify-ca or verify-full.
     97 		if cfg.SSLRootCert != "" {
     98 			if cfg.SSLInline {
     99 				verifyCaOnly = true
    100 			} else if _, err := os.Stat(cfg.SSLRootCert); err == nil {
    101 				verifyCaOnly = true
    102 			} else if cfg.SSLRootCert != "system" {
    103 				cfg.SSLRootCert = ""
    104 			}
    105 		}
    106 	case mode == SSLModeVerifyCA:
    107 		// Skip TLS's own verification since it requires full verification.
    108 		tlsConf.InsecureSkipVerify = true
    109 		verifyCaOnly = true
    110 	case mode == SSLModeVerifyFull:
    111 		tlsConf.ServerName = cfg.Host
    112 	case strings.HasPrefix(string(mode), "pqgo-"):
    113 		tlsConf = getTLSConfigClone(string(mode[5:]))
    114 		if tlsConf == nil {
    115 			return nil, fmt.Errorf(`pq: unknown custom sslmode %q`, mode)
    116 		}
    117 	default:
    118 		panic("unreachable")
    119 	}
    120 
    121 	tlsConf.MinVersion = cfg.SSLMinProtocolVersion.tlsconf()
    122 	tlsConf.MaxVersion = cfg.SSLMaxProtocolVersion.tlsconf()
    123 
    124 	// RFC 6066 asks to not set SNI if the host is a literal IP address (IPv4 or
    125 	// IPv6). This check is coded already crypto.tls.hostnameInSNI, so just
    126 	// always set ServerName here and let crypto/tls do the filtering.
    127 	if cfg.SSLSNI {
    128 		tlsConf.ServerName = cfg.Host
    129 	}
    130 
    131 	err := sslClientCertificates(tlsConf, cfg, home)
    132 	if err != nil {
    133 		return nil, err
    134 	}
    135 	rootPem, err := sslCertificateAuthority(tlsConf, cfg)
    136 	if err != nil {
    137 		return nil, err
    138 	}
    139 	sslAppendIntermediates(tlsConf, cfg, rootPem)
    140 
    141 	// Accept renegotiation requests initiated by the backend.
    142 	//
    143 	// Renegotiation was deprecated then removed from PostgreSQL 9.5, but the
    144 	// default configuration of older versions has it enabled. Redshift also
    145 	// initiates renegotiations and cannot be reconfigured.
    146 	//
    147 	// TODO: I think this can be removed?
    148 	tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient
    149 
    150 	return func(conn net.Conn) (net.Conn, error) {
    151 		client := tls.Client(conn, tlsConf)
    152 		if verifyCaOnly {
    153 			err := client.Handshake()
    154 			if err != nil {
    155 				return client, err
    156 			}
    157 			var (
    158 				certs = client.ConnectionState().PeerCertificates
    159 				opts  = x509.VerifyOptions{Intermediates: x509.NewCertPool(), Roots: tlsConf.RootCAs}
    160 			)
    161 			for _, cert := range certs[1:] {
    162 				opts.Intermediates.AddCert(cert)
    163 			}
    164 			_, err = certs[0].Verify(opts)
    165 			return client, err
    166 		}
    167 		return client, nil
    168 	}, nil
    169 }
    170 
    171 // sslClientCertificates adds the certificate specified in the "sslcert" and
    172 //
    173 // "sslkey" settings, or if they aren't set, from the .postgresql directory
    174 // in the user's home directory. The configured files must exist and have
    175 // the correct permissions.
    176 func sslClientCertificates(tlsConf *tls.Config, cfg Config, home string) error {
    177 	if cfg.SSLInline {
    178 		cert, err := tls.X509KeyPair([]byte(cfg.SSLCert), []byte(cfg.SSLKey))
    179 		if err != nil {
    180 			return err
    181 		}
    182 		// Use GetClientCertificate instead of the Certificates field. When
    183 		// Certificates is set, Go's TLS client only sends the cert if the
    184 		// server's CertificateRequest includes a CA that issued it. When the
    185 		// client cert was signed by an intermediate CA but the server only
    186 		// advertises the root CA, Go skips sending the cert entirely.
    187 		// GetClientCertificate bypasses this filtering.
    188 		tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
    189 			return &cert, nil
    190 		}
    191 		return nil
    192 	}
    193 
    194 	// Only load client certificate and key if the setting is not blank, like libpq.
    195 	if cfg.SSLCert == "" && home != "" {
    196 		cfg.SSLCert = filepath.Join(home, "postgresql.crt")
    197 	}
    198 	if cfg.SSLCert == "" {
    199 		return nil
    200 	}
    201 	_, err := os.Stat(cfg.SSLCert)
    202 	if err != nil {
    203 		if pqutil.ErrNotExists(err) {
    204 			return nil
    205 		}
    206 		return err
    207 	}
    208 
    209 	// In libpq, the ssl key is only loaded if the setting is not blank.
    210 	if cfg.SSLKey == "" && home != "" {
    211 		cfg.SSLKey = filepath.Join(home, "postgresql.key")
    212 	}
    213 	if cfg.SSLKey != "" {
    214 		err := pqutil.SSLKeyPermissions(cfg.SSLKey)
    215 		if err != nil {
    216 			return err
    217 		}
    218 	}
    219 
    220 	cert, err := tls.LoadX509KeyPair(cfg.SSLCert, cfg.SSLKey)
    221 	if err != nil {
    222 		return err
    223 	}
    224 
    225 	// Using GetClientCertificate instead of Certificates per comment above.
    226 	tlsConf.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
    227 		return &cert, nil
    228 	}
    229 	return nil
    230 }
    231 
    232 var testSystemRoots *x509.CertPool
    233 
    234 // sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting.
    235 func sslCertificateAuthority(tlsConf *tls.Config, cfg Config) ([]byte, error) {
    236 	// Only load root certificate if not blank, like libpq.
    237 	if cfg.SSLRootCert == "" {
    238 		return nil, nil
    239 	}
    240 
    241 	if cfg.SSLRootCert == "system" {
    242 		// No work to do as system CAs are used by default if RootCAs is nil.
    243 		tlsConf.RootCAs = testSystemRoots
    244 		return nil, nil
    245 	}
    246 
    247 	tlsConf.RootCAs = x509.NewCertPool()
    248 
    249 	var cert []byte
    250 	if cfg.SSLInline {
    251 		cert = []byte(cfg.SSLRootCert)
    252 	} else {
    253 		var err error
    254 		cert, err = os.ReadFile(cfg.SSLRootCert)
    255 		if err != nil {
    256 			return nil, err
    257 		}
    258 	}
    259 
    260 	if !tlsConf.RootCAs.AppendCertsFromPEM(cert) {
    261 		return nil, errors.New("pq: couldn't parse pem from sslrootcert")
    262 	}
    263 	return cert, nil
    264 }
    265 
    266 // sslAppendIntermediates appends intermediate CA certificates from sslrootcert
    267 // to the client certificate chain. This is needed so the server can verify the
    268 // client cert when it was signed by an intermediate CA — without this, the TLS
    269 // handshake only sends the leaf client cert.
    270 func sslAppendIntermediates(tlsConf *tls.Config, cfg Config, rootPem []byte) {
    271 	if cfg.SSLRootCert == "" || tlsConf.GetClientCertificate == nil || len(rootPem) == 0 {
    272 		return
    273 	}
    274 
    275 	var (
    276 		pemData       = slices.Clone(rootPem)
    277 		intermediates [][]byte
    278 	)
    279 	for {
    280 		var block *pem.Block
    281 		block, pemData = pem.Decode(pemData)
    282 		if block == nil {
    283 			break
    284 		}
    285 		if block.Type != "CERTIFICATE" {
    286 			continue
    287 		}
    288 		cert, err := x509.ParseCertificate(block.Bytes)
    289 		if err != nil {
    290 			continue
    291 		}
    292 		// Skip self-signed root CAs; only append intermediates.
    293 		if cert.IsCA && !bytes.Equal(cert.RawIssuer, cert.RawSubject) {
    294 			intermediates = append(intermediates, block.Bytes)
    295 		}
    296 	}
    297 	if len(intermediates) == 0 {
    298 		return
    299 	}
    300 
    301 	// Wrap the existing GetClientCertificate to append intermediate certs to
    302 	// the certificate chain returned during the TLS handshake.
    303 	origGetCert := tlsConf.GetClientCertificate
    304 	tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
    305 		cert, err := origGetCert(info)
    306 		if err != nil {
    307 			return cert, err
    308 		}
    309 		cert.Certificate = append(cert.Certificate, intermediates...)
    310 		return cert, nil
    311 	}
    312 }