conn.go (48065B)
1 package pq 2 3 import ( 4 "bufio" 5 "context" 6 "crypto/md5" 7 "crypto/sha256" 8 "database/sql" 9 "database/sql/driver" 10 "encoding/binary" 11 "errors" 12 "fmt" 13 "io" 14 "math" 15 "net" 16 "os" 17 "reflect" 18 "strconv" 19 "strings" 20 "sync" 21 "sync/atomic" 22 "time" 23 24 "github.com/lib/pq/internal/pgpass" 25 "github.com/lib/pq/internal/pqsql" 26 "github.com/lib/pq/internal/pqutil" 27 "github.com/lib/pq/internal/proto" 28 "github.com/lib/pq/oid" 29 "github.com/lib/pq/scram" 30 ) 31 32 // Common error types 33 var ( 34 ErrNotSupported = errors.New("pq: unsupported command") 35 ErrInFailedTransaction = errors.New("pq: could not complete operation in a failed transaction") 36 ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") 37 ErrCouldNotDetectUsername = errors.New("pq: could not detect default username; please provide one explicitly") 38 ErrSSLKeyUnknownOwnership = pqutil.ErrSSLKeyUnknownOwnership 39 ErrSSLKeyHasWorldPermissions = pqutil.ErrSSLKeyHasWorldPermissions 40 41 errQueryInProgress = errors.New("pq: there is already a query being processed on this connection") 42 errUnexpectedReady = errors.New("unexpected ReadyForQuery") 43 errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") 44 errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") 45 ) 46 47 // Compile time validation that our types implement the expected interfaces 48 var ( 49 _ driver.Driver = Driver{} 50 _ driver.ConnBeginTx = (*conn)(nil) 51 _ driver.ConnPrepareContext = (*conn)(nil) 52 _ driver.Execer = (*conn)(nil) //lint:ignore SA1019 x 53 _ driver.ExecerContext = (*conn)(nil) 54 _ driver.NamedValueChecker = (*conn)(nil) 55 _ driver.Pinger = (*conn)(nil) 56 _ driver.Queryer = (*conn)(nil) //lint:ignore SA1019 x 57 _ driver.QueryerContext = (*conn)(nil) 58 _ driver.SessionResetter = (*conn)(nil) 59 _ driver.Validator = (*conn)(nil) 60 _ driver.StmtExecContext = (*stmt)(nil) 61 _ driver.StmtQueryContext = (*stmt)(nil) 62 ) 63 64 func init() { 65 sql.Register("postgres", &Driver{}) 66 } 67 68 var debugProto = func() bool { 69 // Check for exactly "1" (rather than mere existence) so we can add 70 // options/flags in the future. I don't know if we ever want that, but it's 71 // nice to leave the option open. 72 return os.Getenv("PQGO_DEBUG") == "1" 73 }() 74 75 // Driver is the Postgres database driver. 76 type Driver struct{} 77 78 // Open opens a new connection to the database. name is a connection string. 79 // Most users should only use it through database/sql package from the standard 80 // library. 81 func (d Driver) Open(name string) (driver.Conn, error) { 82 return Open(name) 83 } 84 85 // Parameters sent by PostgreSQL on startup. 86 type parameterStatus struct { 87 serverVersion int 88 currentLocation *time.Location 89 inHotStandby, defaultTransactionReadOnly sql.NullBool 90 } 91 92 type format int 93 94 const ( 95 formatText format = 0 96 formatBinary format = 1 97 ) 98 99 var ( 100 // One result-column format code with the value 1 (i.e. all binary). 101 colFmtDataAllBinary = []byte{0, 1, 0, 1} 102 103 // No result-column format codes (i.e. all text). 104 colFmtDataAllText = []byte{0, 0} 105 ) 106 107 type transactionStatus byte 108 109 const ( 110 txnStatusIdle transactionStatus = 'I' 111 txnStatusIdleInTransaction transactionStatus = 'T' 112 txnStatusInFailedTransaction transactionStatus = 'E' 113 ) 114 115 func (s transactionStatus) String() string { 116 switch s { 117 case txnStatusIdle: 118 return "idle" 119 case txnStatusIdleInTransaction: 120 return "idle in transaction" 121 case txnStatusInFailedTransaction: 122 return "in a failed transaction" 123 default: 124 panic(fmt.Sprintf("pq: unknown transactionStatus %d", s)) 125 } 126 } 127 128 // Dialer is the dialer interface. It can be used to obtain more control over 129 // how pq creates network connections. 130 type Dialer interface { 131 Dial(network, address string) (net.Conn, error) 132 DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) 133 } 134 135 // DialerContext is the context-aware dialer interface. 136 type DialerContext interface { 137 DialContext(ctx context.Context, network, address string) (net.Conn, error) 138 } 139 140 type defaultDialer struct { 141 d net.Dialer 142 } 143 144 func (d defaultDialer) Dial(network, address string) (net.Conn, error) { 145 return d.d.Dial(network, address) 146 } 147 148 func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { 149 ctx, cancel := context.WithTimeout(context.Background(), timeout) 150 defer cancel() 151 return d.DialContext(ctx, network, address) 152 } 153 154 func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 155 return d.d.DialContext(ctx, network, address) 156 } 157 158 type conn struct { 159 c net.Conn 160 buf *bufio.Reader 161 namei int 162 scratch [512]byte 163 txnStatus transactionStatus 164 txnFinish func() 165 166 // Save connection arguments to use during CancelRequest. 167 dialer Dialer 168 cfg Config 169 parameterStatus parameterStatus 170 171 saveMessageType proto.ResponseCode 172 saveMessageBuffer []byte 173 174 // If an error is set this connection is bad and all public-facing 175 // functions should return the appropriate error by calling get() 176 // (ErrBadConn) or getForNext(). 177 err syncErr 178 179 secretKey []byte // Cancellation key for CancelRequest messages. 180 pid int // Cancellation PID. 181 inProgress atomic.Bool // This connection is in the middle of a processing a request. 182 noticeHandler func(*Error) // If not nil, notices will be synchronously sent here 183 notificationHandler func(*Notification) // If not nil, notifications will be synchronously sent here 184 gss GSS // GSSAPI context 185 } 186 187 type syncErr struct { 188 err error 189 sync.Mutex 190 } 191 192 // Return ErrBadConn if connection is bad. 193 func (e *syncErr) get() error { 194 e.Lock() 195 defer e.Unlock() 196 if e.err != nil { 197 return driver.ErrBadConn 198 } 199 return nil 200 } 201 202 // Return the error set on the connection. Currently only used by rows.Next. 203 func (e *syncErr) getForNext() error { 204 e.Lock() 205 defer e.Unlock() 206 return e.err 207 } 208 209 // Set error, only if it isn't set yet. 210 func (e *syncErr) set(err error) { 211 if err == nil { 212 panic("attempt to set nil err") 213 } 214 e.Lock() 215 defer e.Unlock() 216 if e.err == nil { 217 e.err = err 218 } 219 } 220 221 func (cn *conn) writeBuf(b proto.RequestCode) *writeBuf { 222 cn.scratch[0] = byte(b) 223 return &writeBuf{ 224 buf: cn.scratch[:5], 225 pos: 1, 226 } 227 } 228 229 // Open opens a new connection to the database. dsn is a connection string. Most 230 // users should only use it through database/sql package from the standard 231 // library. 232 func Open(dsn string) (_ driver.Conn, err error) { 233 return DialOpen(defaultDialer{}, dsn) 234 } 235 236 // DialOpen opens a new connection to the database using a dialer. 237 func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { 238 c, err := NewConnector(dsn) 239 if err != nil { 240 return nil, err 241 } 242 c.Dialer(d) 243 return c.open(context.Background()) 244 } 245 246 func (c *Connector) open(ctx context.Context) (*conn, error) { 247 tsa := c.cfg.TargetSessionAttrs 248 restartAll: 249 var ( 250 errs []error 251 app = func(err error, cfg Config) bool { 252 if err != nil { 253 if debugProto { 254 fmt.Fprintln(os.Stderr, "CONNECT (error)", err) 255 } 256 errs = append(errs, fmt.Errorf("connecting to %s:%d: %w", cfg.Host, cfg.Port, err)) 257 } 258 return err != nil 259 } 260 ) 261 for _, cfg := range c.cfg.hosts() { 262 mode := cfg.SSLMode 263 restartHost: 264 if debugProto { 265 fmt.Fprintln(os.Stderr, "CONNECT ", cfg.string()) 266 } 267 268 cfg.SSLMode = mode 269 cn := &conn{cfg: cfg, dialer: c.dialer} 270 cn.cfg.Password = pgpass.PasswordFromPgpass(cn.cfg.Passfile, cn.cfg.User, cn.cfg.Password, 271 cn.cfg.Host, strconv.Itoa(int(cn.cfg.Port)), cn.cfg.Database) 272 273 var err error 274 cn.c, err = dial(ctx, c.dialer, cn.cfg) 275 if app(err, cfg) { 276 continue 277 } 278 279 err = cn.ssl(cn.cfg, mode) 280 if err != nil && mode == SSLModePrefer { 281 mode = SSLModeDisable 282 goto restartHost 283 } 284 if app(err, cfg) { 285 if cn.c != nil { 286 _ = cn.c.Close() 287 } 288 continue 289 } 290 291 cn.buf = bufio.NewReader(cn.c) 292 err = cn.startup(cn.cfg) 293 if err != nil && mode == SSLModeAllow { 294 mode = SSLModeRequire 295 goto restartHost 296 } 297 if app(err, cfg) { 298 _ = cn.c.Close() 299 continue 300 } 301 302 // Reset the deadline, in case one was set (see dial) 303 if cn.cfg.ConnectTimeout > 0 { 304 err := cn.c.SetDeadline(time.Time{}) 305 if app(err, cfg) { 306 _ = cn.c.Close() 307 continue 308 } 309 } 310 311 err = cn.checkTSA(tsa) 312 if app(err, cfg) { 313 _ = cn.c.Close() 314 continue 315 } 316 317 return cn, nil 318 } 319 320 // target_session_attrs=prefer-standby is treated as standby in checkTSA; we 321 // ran out of hosts so none are on standby. Clear the setting and try again. 322 if c.cfg.TargetSessionAttrs == TargetSessionAttrsPreferStandby { 323 tsa = TargetSessionAttrsAny 324 goto restartAll 325 } 326 327 if len(c.cfg.Multi) == 0 { 328 // Remove the "connecting to [..]" when we have just one host, so the 329 // error is identical to what we had before. 330 return nil, errors.Unwrap(errs[0]) 331 } 332 return nil, fmt.Errorf("pq: could not connect to any of the hosts:\n%w", errors.Join(errs...)) 333 } 334 335 func (cn *conn) getBool(query string) (bool, error) { 336 res, err := cn.simpleQuery(query) 337 if err != nil { 338 return false, err 339 } 340 defer res.Close() 341 342 v := make([]driver.Value, 1) 343 err = res.Next(v) 344 if err != nil { 345 return false, err 346 } 347 348 switch vv := v[0].(type) { 349 default: 350 return false, fmt.Errorf("parseBool: unknown type %T: %[1]v", v[0]) 351 case bool: 352 return vv, nil 353 case string: 354 vv, ok := v[0].(string) 355 if !ok { 356 return false, err 357 } 358 return vv == "on", nil 359 } 360 } 361 362 func (cn *conn) checkTSA(tsa TargetSessionAttrs) error { 363 var ( 364 geths = func() (hs bool, err error) { 365 hs = cn.parameterStatus.inHotStandby.Bool 366 if !cn.parameterStatus.inHotStandby.Valid { 367 hs, err = cn.getBool("select pg_catalog.pg_is_in_recovery()") 368 } 369 return hs, err 370 } 371 getro = func() (ro bool, err error) { 372 ro = cn.parameterStatus.defaultTransactionReadOnly.Bool 373 if !cn.parameterStatus.defaultTransactionReadOnly.Valid { 374 ro, err = cn.getBool("show transaction_read_only") 375 } 376 return ro, err 377 } 378 ) 379 380 switch tsa { 381 default: 382 panic("unreachable") 383 case "", TargetSessionAttrsAny: 384 return nil 385 case TargetSessionAttrsReadWrite, TargetSessionAttrsReadOnly: 386 readonly, err := getro() 387 if err != nil { 388 return err 389 } 390 if !cn.parameterStatus.defaultTransactionReadOnly.Valid { 391 var err error 392 readonly, err = cn.getBool("show transaction_read_only") 393 if err != nil { 394 return err 395 } 396 } 397 switch { 398 case tsa == TargetSessionAttrsReadOnly && !readonly: 399 return errors.New("session is not read-only") 400 case tsa == TargetSessionAttrsReadWrite: 401 if readonly { 402 return errors.New("session is read-only") 403 } 404 hs, err := geths() 405 if err != nil { 406 return err 407 } 408 if hs { 409 return errors.New("server is in hot standby mode") 410 } 411 return nil 412 default: 413 return nil 414 } 415 case TargetSessionAttrsPrimary, TargetSessionAttrsStandby, TargetSessionAttrsPreferStandby: 416 hs, err := geths() 417 if err != nil { 418 return err 419 } 420 switch { 421 case (tsa == TargetSessionAttrsStandby || tsa == TargetSessionAttrsPreferStandby) && !hs: 422 return errors.New("server is not in hot standby mode") 423 case tsa == TargetSessionAttrsPrimary && hs: 424 return errors.New("server is in hot standby mode") 425 default: 426 return nil 427 } 428 } 429 } 430 431 func dial(ctx context.Context, d Dialer, cfg Config) (net.Conn, error) { 432 network, address := cfg.network() 433 434 // Zero or not specified means wait indefinitely. 435 if cfg.ConnectTimeout > 0 { 436 // connect_timeout should apply to the entire connection establishment 437 // procedure, so we both use a timeout for the TCP connection 438 // establishment and set a deadline for doing the initial handshake. The 439 // deadline is then reset after startup() is done. 440 var ( 441 deadline = time.Now().Add(cfg.ConnectTimeout) 442 conn net.Conn 443 err error 444 ) 445 if dctx, ok := d.(DialerContext); ok { 446 ctx, cancel := context.WithTimeout(ctx, cfg.ConnectTimeout) 447 defer cancel() 448 conn, err = dctx.DialContext(ctx, network, address) 449 } else { 450 conn, err = d.DialTimeout(network, address, cfg.ConnectTimeout) 451 } 452 if err != nil { 453 return nil, err 454 } 455 err = conn.SetDeadline(deadline) 456 return conn, err 457 } 458 if dctx, ok := d.(DialerContext); ok { 459 return dctx.DialContext(ctx, network, address) 460 } 461 return d.Dial(network, address) 462 } 463 464 func (cn *conn) isInTransaction() bool { 465 return cn.txnStatus == txnStatusIdleInTransaction || 466 cn.txnStatus == txnStatusInFailedTransaction 467 } 468 469 func (cn *conn) checkIsInTransaction(intxn bool) error { 470 if cn.isInTransaction() != intxn { 471 cn.err.set(driver.ErrBadConn) 472 return fmt.Errorf("pq: unexpected transaction status %v", cn.txnStatus) 473 } 474 return nil 475 } 476 477 func (cn *conn) Begin() (_ driver.Tx, err error) { 478 return cn.begin("") 479 } 480 481 func (cn *conn) begin(mode string) (_ driver.Tx, err error) { 482 if err := cn.err.get(); err != nil { 483 return nil, err 484 } 485 if err := cn.checkIsInTransaction(false); err != nil { 486 return nil, err 487 } 488 489 _, commandTag, err := cn.simpleExec("BEGIN" + mode) 490 if err != nil { 491 return nil, cn.handleError(err) 492 } 493 if commandTag != "BEGIN" { 494 cn.err.set(driver.ErrBadConn) 495 return nil, fmt.Errorf("unexpected command tag %s", commandTag) 496 } 497 if cn.txnStatus != txnStatusIdleInTransaction { 498 cn.err.set(driver.ErrBadConn) 499 return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) 500 } 501 return cn, nil 502 } 503 504 func (cn *conn) closeTxn() { 505 if finish := cn.txnFinish; finish != nil { 506 finish() 507 } 508 } 509 510 func (cn *conn) Commit() error { 511 defer cn.closeTxn() 512 if err := cn.err.get(); err != nil { 513 return err 514 } 515 if err := cn.checkIsInTransaction(true); err != nil { 516 return err 517 } 518 519 // We don't want the client to think that everything is okay if it tries 520 // to commit a failed transaction. However, no matter what we return, 521 // database/sql will release this connection back into the free connection 522 // pool so we have to abort the current transaction here. Note that you 523 // would get the same behaviour if you issued a COMMIT in a failed 524 // transaction, so it's also the least surprising thing to do here. 525 if cn.txnStatus == txnStatusInFailedTransaction { 526 if err := cn.rollback(); err != nil { 527 return err 528 } 529 return ErrInFailedTransaction 530 } 531 532 _, commandTag, err := cn.simpleExec("COMMIT") 533 if err != nil { 534 if cn.isInTransaction() { 535 cn.err.set(driver.ErrBadConn) 536 } 537 return cn.handleError(err) 538 } 539 if commandTag != "COMMIT" { 540 cn.err.set(driver.ErrBadConn) 541 return fmt.Errorf("unexpected command tag %s", commandTag) 542 } 543 return cn.checkIsInTransaction(false) 544 } 545 546 func (cn *conn) Rollback() error { 547 defer cn.closeTxn() 548 if err := cn.err.get(); err != nil { 549 return err 550 } 551 552 err := cn.rollback() 553 if err != nil { 554 return cn.handleError(err) 555 } 556 return nil 557 } 558 559 func (cn *conn) rollback() (err error) { 560 if err := cn.checkIsInTransaction(true); err != nil { 561 return err 562 } 563 564 _, commandTag, err := cn.simpleExec("ROLLBACK") 565 if err != nil { 566 if cn.isInTransaction() { 567 cn.err.set(driver.ErrBadConn) 568 } 569 return err 570 } 571 if commandTag != "ROLLBACK" { 572 return fmt.Errorf("unexpected command tag %s", commandTag) 573 } 574 return cn.checkIsInTransaction(false) 575 } 576 577 func (cn *conn) gname() string { 578 cn.namei++ 579 return strconv.FormatInt(int64(cn.namei), 10) 580 } 581 582 func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, resErr error) { 583 if debugProto { 584 fmt.Fprintln(os.Stderr, " START conn.simpleExec") 585 defer fmt.Fprintln(os.Stderr, " END conn.simpleExec") 586 } 587 588 b := cn.writeBuf(proto.Query) 589 b.string(q) 590 err := cn.send(b) 591 if err != nil { 592 return nil, "", err 593 } 594 595 for { 596 t, r, err := cn.recv1() 597 if err != nil { 598 return nil, "", err 599 } 600 switch t { 601 case proto.CommandComplete: 602 res, commandTag, err = cn.parseComplete(r.string()) 603 if err != nil { 604 return nil, "", err 605 } 606 case proto.ReadyForQuery: 607 cn.processReadyForQuery(r) 608 if res == nil && resErr == nil { 609 resErr = errUnexpectedReady 610 } 611 return res, commandTag, resErr 612 case proto.ErrorResponse: 613 resErr = parseError(r, q) 614 case proto.EmptyQueryResponse: 615 res = emptyRows 616 case proto.RowDescription, proto.DataRow: 617 // ignore any results 618 default: 619 cn.err.set(driver.ErrBadConn) 620 return nil, "", fmt.Errorf("pq: unknown response for simple query: %q", t) 621 } 622 } 623 } 624 625 func (cn *conn) simpleQuery(q string) (*rows, error) { 626 if debugProto { 627 fmt.Fprintln(os.Stderr, " START conn.simpleQuery") 628 defer fmt.Fprintln(os.Stderr, " END conn.simpleQuery") 629 } 630 631 b := cn.writeBuf(proto.Query) 632 b.string(q) 633 err := cn.send(b) 634 if err != nil { 635 return nil, cn.handleError(err, q) 636 } 637 638 var ( 639 res *rows 640 resErr error 641 ) 642 for { 643 t, r, err := cn.recv1() 644 if err != nil { 645 return nil, cn.handleError(err, q) 646 } 647 switch t { 648 case proto.CommandComplete, proto.EmptyQueryResponse: 649 // We allow queries which don't return any results through Query as 650 // well as Exec. We still have to give database/sql a rows object 651 // the user can close, though, to avoid connections from being 652 // leaked. A "rows" with done=true works fine for that purpose. 653 if resErr != nil { 654 cn.err.set(driver.ErrBadConn) 655 return nil, fmt.Errorf("pq: unexpected message %q in simple query execution", t) 656 } 657 if res == nil { 658 res = &rows{cn: cn} 659 } 660 // Set the result and tag to the last command complete if there wasn't a 661 // query already run. Although queries usually return from here and cede 662 // control to Next, a query with zero results does not. 663 if t == proto.CommandComplete { 664 res.result, res.tag, err = cn.parseComplete(r.string()) 665 if err != nil { 666 return nil, cn.handleError(err, q) 667 } 668 if res.colNames != nil { 669 return res, cn.handleError(resErr, q) 670 } 671 } 672 res.done = true 673 case proto.ReadyForQuery: 674 cn.processReadyForQuery(r) 675 if err == nil && res == nil { 676 res = &rows{done: true} 677 } 678 return res, cn.handleError(resErr, q) // done 679 case proto.ErrorResponse: 680 res = nil 681 resErr = parseError(r, q) 682 case proto.DataRow: 683 if res == nil { 684 cn.err.set(driver.ErrBadConn) 685 return nil, fmt.Errorf("pq: unexpected DataRow in simple query execution") 686 } 687 return res, cn.saveMessage(t, r) // The query didn't fail; kick off to Next 688 case proto.RowDescription: 689 // res might be non-nil here if we received a previous 690 // CommandComplete, but that's fine and just overwrite it. 691 res = &rows{cn: cn, rowsHeader: parsePortalRowDescribe(r)} 692 693 // To work around a bug in QueryRow in Go 1.2 and earlier, wait 694 // until the first DataRow has been received. 695 default: 696 cn.err.set(driver.ErrBadConn) 697 return nil, fmt.Errorf("pq: unknown response for simple query: %q", t) 698 } 699 } 700 } 701 702 // Decides which column formats to use for a prepared statement. The input is 703 // an array of type oids, one element per result column. 704 func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte, _ error) { 705 if len(colTyps) == 0 { 706 return nil, colFmtDataAllText, nil 707 } 708 709 colFmts = make([]format, len(colTyps)) 710 if forceText { 711 return colFmts, colFmtDataAllText, nil 712 } 713 714 allBinary := true 715 allText := true 716 for i, t := range colTyps { 717 switch t.OID { 718 // This is the list of types to use binary mode for when receiving them 719 // through a prepared statement. If a type appears in this list, it 720 // must also be implemented in binaryDecode in encode.go. 721 case oid.T_bytea: 722 fallthrough 723 case oid.T_int8: 724 fallthrough 725 case oid.T_int4: 726 fallthrough 727 case oid.T_int2: 728 fallthrough 729 case oid.T_uuid: 730 colFmts[i] = formatBinary 731 allText = false 732 default: 733 allBinary = false 734 } 735 } 736 737 if allBinary { 738 return colFmts, colFmtDataAllBinary, nil 739 } else if allText { 740 return colFmts, colFmtDataAllText, nil 741 } else { 742 colFmtData = make([]byte, 2+len(colFmts)*2) 743 if len(colFmts) > math.MaxUint16 { 744 return nil, nil, fmt.Errorf("pq: too many columns (%d > math.MaxUint16)", len(colFmts)) 745 } 746 binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) 747 for i, v := range colFmts { 748 binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) 749 } 750 return colFmts, colFmtData, nil 751 } 752 } 753 754 func (cn *conn) prepareTo(q, stmtName string) (*stmt, error) { 755 if debugProto { 756 fmt.Fprintln(os.Stderr, " START conn.prepareTo") 757 defer fmt.Fprintln(os.Stderr, " END conn.prepareTo") 758 } 759 760 st := &stmt{cn: cn, name: stmtName} 761 762 b := cn.writeBuf(proto.Parse) 763 b.string(st.name) 764 b.string(q) 765 b.int16(0) 766 767 b.next(proto.Describe) 768 b.byte(proto.Sync) 769 b.string(st.name) 770 771 b.next(proto.Sync) 772 err := cn.send(b) 773 if err != nil { 774 return nil, err 775 } 776 777 err = cn.readParseResponse() 778 if err != nil { 779 return nil, err 780 } 781 st.paramTyps, st.colNames, st.colTyps, err = cn.readStatementDescribeResponse() 782 if err != nil { 783 return nil, err 784 } 785 st.colFmts, st.colFmtData, err = decideColumnFormats(st.colTyps, cn.cfg.DisablePreparedBinaryResult) 786 if err != nil { 787 return nil, err 788 } 789 790 err = cn.readReadyForQuery() 791 if err != nil { 792 return nil, err 793 } 794 return st, nil 795 } 796 797 func (cn *conn) Prepare(q string) (driver.Stmt, error) { 798 if err := cn.err.get(); err != nil { 799 return nil, err 800 } 801 802 if pqsql.StartsWithCopy(q) { 803 s, err := cn.prepareCopyIn(q) 804 if err == nil { 805 cn.inProgress.Store(true) 806 } 807 return s, cn.handleError(err, q) 808 } 809 s, err := cn.prepareTo(q, cn.gname()) 810 if err != nil { 811 return nil, cn.handleError(err, q) 812 } 813 return s, nil 814 } 815 816 func (cn *conn) Close() error { 817 // Don't go through send(); ListenerConn relies on us not scribbling on the 818 // scratch buffer of this connection. 819 err := cn.sendSimpleMessage(proto.Terminate) 820 if err != nil { 821 _ = cn.c.Close() // Ensure that cn.c.Close is always run. 822 return cn.handleError(err) 823 } 824 return cn.c.Close() 825 } 826 827 func toNamedValue(v []driver.Value) []driver.NamedValue { 828 v2 := make([]driver.NamedValue, len(v)) 829 for i := range v { 830 v2[i] = driver.NamedValue{Ordinal: i + 1, Value: v[i]} 831 } 832 return v2 833 } 834 835 // CheckNamedValue implements [driver.NamedValueChecker]. 836 func (cn *conn) CheckNamedValue(nv *driver.NamedValue) error { 837 if cn.cfg.BinaryParameters { 838 if bin, ok := nv.Value.(interface{ BinaryValue() ([]byte, error) }); ok { 839 var err error 840 nv.Value, err = bin.BinaryValue() 841 return err 842 } 843 } 844 845 // Ignore Valuer, for backward compatibility with pq.Array(). 846 if _, ok := nv.Value.(driver.Valuer); ok { 847 return driver.ErrSkip 848 } 849 850 v := reflect.ValueOf(nv.Value) 851 if !v.IsValid() { 852 return driver.ErrSkip 853 } 854 t := v.Type() 855 for t.Kind() == reflect.Pointer { 856 t, v = t.Elem(), v.Elem() 857 } 858 859 // Ignore []byte and related types: *[]byte, json.RawMessage, etc. 860 if t.Kind() == reflect.Slice && t.Elem().Kind() == reflect.Uint8 { 861 return driver.ErrSkip 862 } 863 864 switch v.Kind() { 865 default: 866 return driver.ErrSkip 867 case reflect.Slice: 868 var err error 869 nv.Value, err = Array(v.Interface()).Value() 870 return err 871 case reflect.Uint64: 872 value := v.Uint() 873 if value >= math.MaxInt64 { 874 nv.Value = strconv.FormatUint(value, 10) 875 } else { 876 nv.Value = int64(value) 877 } 878 return nil 879 } 880 } 881 882 // Implement the "Queryer" interface 883 func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { 884 return cn.query(query, toNamedValue(args)) 885 } 886 887 func (cn *conn) query(query string, args []driver.NamedValue) (*rows, error) { 888 if debugProto { 889 fmt.Fprintln(os.Stderr, " START conn.query") 890 defer fmt.Fprintln(os.Stderr, " END conn.query") 891 } 892 if err := cn.err.get(); err != nil { 893 return nil, err 894 } 895 if !cn.inProgress.CompareAndSwap(false, true) { 896 return nil, errQueryInProgress 897 } 898 899 // Check to see if we can use the "simpleQuery" interface, which is 900 // *much* faster than going through prepare/exec 901 if len(args) == 0 { 902 return cn.simpleQuery(query) 903 } 904 905 if cn.cfg.BinaryParameters { 906 err := cn.sendBinaryModeQuery(query, args) 907 if err != nil { 908 return nil, cn.handleError(err, query) 909 } 910 err = cn.readParseResponse() 911 if err != nil { 912 return nil, cn.handleError(err, query) 913 } 914 err = cn.readBindResponse() 915 if err != nil { 916 return nil, cn.handleError(err, query) 917 } 918 919 rows := &rows{cn: cn} 920 rows.rowsHeader, err = cn.readPortalDescribeResponse() 921 if err != nil { 922 return nil, cn.handleError(err, query) 923 } 924 err = cn.postExecuteWorkaround() 925 if err != nil { 926 return nil, cn.handleError(err, query) 927 } 928 return rows, nil 929 } 930 931 st, err := cn.prepareTo(query, "") 932 if err != nil { 933 return nil, cn.handleError(err, query) 934 } 935 err = st.exec(args) 936 if err != nil { 937 return nil, cn.handleError(err, query) 938 } 939 return &rows{ 940 cn: cn, 941 rowsHeader: st.rowsHeader, 942 }, nil 943 } 944 945 // Implement the optional "Execer" interface for one-shot queries 946 func (cn *conn) Exec(query string, args []driver.Value) (driver.Result, error) { 947 if err := cn.err.get(); err != nil { 948 return nil, err 949 } 950 if !cn.inProgress.CompareAndSwap(false, true) { 951 return nil, errQueryInProgress 952 } 953 954 // Check to see if we can use the "simpleExec" interface, which is *much* 955 // faster than going through prepare/exec 956 if len(args) == 0 { 957 // ignore commandTag, our caller doesn't care 958 r, _, err := cn.simpleExec(query) 959 return r, cn.handleError(err, query) 960 } 961 962 if cn.cfg.BinaryParameters { 963 err := cn.sendBinaryModeQuery(query, toNamedValue(args)) 964 if err != nil { 965 return nil, cn.handleError(err, query) 966 } 967 err = cn.readParseResponse() 968 if err != nil { 969 return nil, cn.handleError(err, query) 970 } 971 err = cn.readBindResponse() 972 if err != nil { 973 return nil, cn.handleError(err, query) 974 } 975 976 _, err = cn.readPortalDescribeResponse() 977 if err != nil { 978 return nil, cn.handleError(err, query) 979 } 980 err = cn.postExecuteWorkaround() 981 if err != nil { 982 return nil, cn.handleError(err, query) 983 } 984 res, _, err := cn.readExecuteResponse("Execute") 985 return res, cn.handleError(err, query) 986 } 987 988 // Use the unnamed statement to defer planning until bind time, or else 989 // value-based selectivity estimates cannot be used. 990 st, err := cn.prepareTo(query, "") 991 if err != nil { 992 return nil, cn.handleError(err, query) 993 } 994 r, err := st.Exec(args) 995 if err != nil { 996 return nil, cn.handleError(err, query) 997 } 998 return r, nil 999 } 1000 1001 type safeRetryError struct{ Err error } 1002 1003 func (se *safeRetryError) Error() string { return se.Err.Error() } 1004 1005 func (cn *conn) send(m *writeBuf) error { 1006 if debugProto { 1007 w := m.wrap() 1008 for len(w) > 0 { // Can contain multiple messages. 1009 c := proto.RequestCode(w[0]) 1010 l := int(binary.BigEndian.Uint32(w[1:5])) - 4 1011 fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", c, l, w[5:l+5]) 1012 w = w[l+5:] 1013 } 1014 } 1015 1016 n, err := cn.c.Write(m.wrap()) 1017 if err != nil && n == 0 { 1018 err = &safeRetryError{Err: err} 1019 } 1020 return err 1021 } 1022 1023 func (cn *conn) sendStartupPacket(m *writeBuf) error { 1024 if debugProto { 1025 w := m.wrap() 1026 fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", "Startup", int(binary.BigEndian.Uint32(w[1:5]))-4, w[5:]) 1027 } 1028 _, err := cn.c.Write((m.wrap())[1:]) 1029 return err 1030 } 1031 1032 // Send a message of type typ to the server on the other end of cn. The message 1033 // should have no payload. This method does not use the scratch buffer. 1034 func (cn *conn) sendSimpleMessage(typ proto.RequestCode) error { 1035 if debugProto { 1036 fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", typ, 0, []byte{}) 1037 } 1038 _, err := cn.c.Write([]byte{byte(typ), '\x00', '\x00', '\x00', '\x04'}) 1039 return err 1040 } 1041 1042 // saveMessage memorizes a message and its buffer in the conn struct. 1043 // recvMessage will then return these values on the next call to it. This 1044 // method is useful in cases where you have to see what the next message is 1045 // going to be (e.g. to see whether it's an error or not) but you can't handle 1046 // the message yourself. 1047 func (cn *conn) saveMessage(typ proto.ResponseCode, buf *readBuf) error { 1048 if cn.saveMessageType != 0 { 1049 cn.err.set(driver.ErrBadConn) 1050 return fmt.Errorf("unexpected saveMessageType %d", cn.saveMessageType) 1051 } 1052 cn.saveMessageType = typ 1053 cn.saveMessageBuffer = *buf 1054 return nil 1055 } 1056 1057 // recvMessage receives any message from the backend, or returns an error if 1058 // a problem occurred while reading the message. 1059 func (cn *conn) recvMessage(r *readBuf) (proto.ResponseCode, error) { 1060 // workaround for a QueryRow bug, see exec 1061 if cn.saveMessageType != 0 { 1062 t := cn.saveMessageType 1063 *r = cn.saveMessageBuffer 1064 cn.saveMessageType = 0 1065 cn.saveMessageBuffer = nil 1066 return t, nil 1067 } 1068 1069 x := cn.scratch[:5] 1070 _, err := io.ReadFull(cn.buf, x) 1071 if err != nil { 1072 return 0, err 1073 } 1074 1075 // Read the type and length of the message that follows. 1076 t := proto.ResponseCode(x[0]) 1077 n := int(binary.BigEndian.Uint32(x[1:])) - 4 1078 1079 if proto.ResponseCode(t) == proto.ReadyForQuery { 1080 cn.inProgress.Store(false) 1081 } 1082 1083 // When PostgreSQL cannot start a backend (e.g., an external process limit), 1084 // it sends plain text like "Ecould not fork new process [..]", which 1085 // doesn't use the standard encoding for the Error message. 1086 // 1087 // libpq checks "if ErrorResponse && (msgLength < 8 || msgLength > MAX_ERRLEN)", 1088 // but check < 4 since n represents bytes remaining to be read after length. 1089 if t == proto.ErrorResponse && (n < 4 || n > proto.MaxErrlen) { 1090 msg, _ := cn.buf.ReadString('\x00') 1091 return 0, fmt.Errorf("pq: server error: %s%s", string(x[1:]), strings.TrimSuffix(msg, "\x00")) 1092 } 1093 1094 var y []byte 1095 if n <= len(cn.scratch) { 1096 y = cn.scratch[:n] 1097 } else { 1098 y = make([]byte, n) 1099 } 1100 _, err = io.ReadFull(cn.buf, y) 1101 if err != nil { 1102 return 0, err 1103 } 1104 *r = y 1105 if debugProto { 1106 fmt.Fprintf(os.Stderr, "SERVER ← %-20s %5d %q\n", t, n, y) 1107 } 1108 return t, nil 1109 } 1110 1111 // recv receives a message from the backend, returning an error if an error 1112 // happened while reading the message or the received message an ErrorResponse. 1113 // NoticeResponses are ignored. This function should generally be used only 1114 // during the startup sequence. 1115 func (cn *conn) recv() (proto.ResponseCode, *readBuf, error) { 1116 for { 1117 r := new(readBuf) 1118 t, err := cn.recvMessage(r) 1119 if err != nil { 1120 return 0, nil, err 1121 } 1122 switch t { 1123 case proto.ErrorResponse: 1124 return 0, nil, parseError(r, "") 1125 case proto.NoticeResponse: 1126 if n := cn.noticeHandler; n != nil { 1127 n(parseError(r, "")) 1128 } 1129 case proto.NotificationResponse: 1130 if n := cn.notificationHandler; n != nil { 1131 n(recvNotification(r)) 1132 } 1133 default: 1134 return t, r, nil 1135 } 1136 } 1137 } 1138 1139 // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by 1140 // the caller to avoid an allocation. 1141 func (cn *conn) recv1Buf(r *readBuf) (proto.ResponseCode, error) { 1142 for { 1143 t, err := cn.recvMessage(r) 1144 if err != nil { 1145 return 0, err 1146 } 1147 1148 switch t { 1149 case proto.NotificationResponse: 1150 if n := cn.notificationHandler; n != nil { 1151 n(recvNotification(r)) 1152 } 1153 case proto.NoticeResponse: 1154 if n := cn.noticeHandler; n != nil { 1155 n(parseError(r, "")) 1156 } 1157 case proto.ParameterStatus: 1158 cn.processParameterStatus(r) 1159 default: 1160 return t, nil 1161 } 1162 } 1163 } 1164 1165 // recv1 receives a message from the backend, returning an error if an error 1166 // happened while reading the message or the received message an ErrorResponse. 1167 // All asynchronous messages are ignored, with the exception of ErrorResponse. 1168 func (cn *conn) recv1() (proto.ResponseCode, *readBuf, error) { 1169 r := new(readBuf) 1170 t, err := cn.recv1Buf(r) 1171 if err != nil { 1172 return 0, nil, err 1173 } 1174 return t, r, nil 1175 } 1176 1177 // Don't refer to Config.SSLMode here, as the mode in arguments may be different 1178 // in case of sslmode=allow or prefer. 1179 func (cn *conn) ssl(cfg Config, mode SSLMode) error { 1180 upgrade, err := ssl(cfg, mode) 1181 if err != nil { 1182 return err 1183 } 1184 if upgrade == nil { 1185 return nil // Nothing to do 1186 } 1187 1188 // Only negotiate the ssl handshake if requested (which is the default). 1189 // sslnegotiation=direct is supported by pg17 and above. 1190 if cfg.SSLNegotiation != SSLNegotiationDirect { 1191 w := cn.writeBuf(0) 1192 w.int32(proto.NegotiateSSLCode) 1193 if err = cn.sendStartupPacket(w); err != nil { 1194 return err 1195 } 1196 1197 b := cn.scratch[:1] 1198 _, err = io.ReadFull(cn.c, b) 1199 if err != nil { 1200 return err 1201 } 1202 1203 if b[0] != 'S' { 1204 return ErrSSLNotSupported 1205 } 1206 } 1207 1208 cn.c, err = upgrade(cn.c) 1209 return err 1210 } 1211 1212 func (cn *conn) startup(cfg Config) error { 1213 w := cn.writeBuf(0) 1214 // Send maximum protocol version in startup; if the server doesn't support 1215 // this version it responds with NegotiateProtocolVersion and the maximum 1216 // version it supports (and will use). 1217 w.int32(cfg.MaxProtocolVersion.proto()) 1218 1219 if cfg.User != "" { 1220 w.string("user") 1221 w.string(cfg.User) 1222 } 1223 if cfg.Database != "" { 1224 w.string("database") 1225 w.string(cfg.Database) 1226 } 1227 // w.string("replication") // Sent by libpq, but we don't support that. 1228 if cfg.Options != "" { 1229 w.string("options") 1230 w.string(cfg.Options) 1231 } 1232 if cfg.ApplicationName != "" { 1233 w.string("application_name") 1234 w.string(cfg.ApplicationName) 1235 } 1236 if cfg.ClientEncoding != "" { 1237 w.string("client_encoding") 1238 w.string(cfg.ClientEncoding) 1239 } 1240 if cfg.Datestyle != "" { 1241 w.string("datestyle") 1242 w.string(cfg.Datestyle) 1243 } 1244 for k, v := range cfg.Runtime { 1245 w.string(k) 1246 w.string(v) 1247 } 1248 1249 w.string("") 1250 if err := cn.sendStartupPacket(w); err != nil { 1251 return err 1252 } 1253 1254 for { 1255 t, r, err := cn.recv() 1256 if err != nil { 1257 return err 1258 } 1259 switch t { 1260 case proto.BackendKeyData: 1261 cn.pid = r.int32() 1262 if len(*r) > 256 { 1263 return fmt.Errorf("pq: cancellation key longer than 256 bytes: %d bytes", len(*r)) 1264 } 1265 cn.secretKey = make([]byte, len(*r)) 1266 copy(cn.secretKey, *r) 1267 case proto.ParameterStatus: 1268 cn.processParameterStatus(r) 1269 case proto.AuthenticationRequest: 1270 err := cn.auth(r, cfg) 1271 if err != nil { 1272 return err 1273 } 1274 case proto.NegotiateProtocolVersion: 1275 newestMinor := r.int32() 1276 serverVersion := proto.ProtocolVersion30&0xFFFF0000 | newestMinor 1277 if serverVersion < cfg.MinProtocolVersion.proto() { 1278 return fmt.Errorf("pq: protocol version mismatch: min_protocol_version=%s; server supports up to 3.%d", cfg.MinProtocolVersion, newestMinor) 1279 } 1280 case proto.ReadyForQuery: 1281 cn.processReadyForQuery(r) 1282 return nil 1283 default: 1284 return fmt.Errorf("pq: unknown response for startup: %q", t) 1285 } 1286 } 1287 } 1288 1289 func (cn *conn) auth(r *readBuf, cfg Config) error { 1290 switch code := proto.AuthCode(r.int32()); code { 1291 default: 1292 return fmt.Errorf("pq: unknown authentication response: %s", code) 1293 case proto.AuthReqKrb4, proto.AuthReqKrb5, proto.AuthReqCrypt, proto.AuthReqSSPI: 1294 return fmt.Errorf("pq: unsupported authentication method: %s", code) 1295 case proto.AuthReqOk: 1296 return nil 1297 1298 case proto.AuthReqPassword: 1299 w := cn.writeBuf(proto.PasswordMessage) 1300 w.string(cfg.Password) 1301 // Don't need to check AuthOk response here; auth() is called in a loop, 1302 // which catches the errors and AuthReqOk responses. 1303 return cn.send(w) 1304 1305 case proto.AuthReqMD5: 1306 s := string(r.next(4)) 1307 w := cn.writeBuf(proto.PasswordMessage) 1308 w.string("md5" + md5s(md5s(cfg.Password+cfg.User)+s)) 1309 // Same here. 1310 return cn.send(w) 1311 1312 case proto.AuthReqGSS: // GSSAPI, startup 1313 if newGss == nil { 1314 return fmt.Errorf("pq: kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos)") 1315 } 1316 cli, err := newGss() 1317 if err != nil { 1318 return fmt.Errorf("pq: kerberos error: %w", err) 1319 } 1320 1321 var token []byte 1322 if cfg.KrbSpn != "" { 1323 // Use the supplied SPN if provided. 1324 token, err = cli.GetInitTokenFromSpn(cfg.KrbSpn) 1325 } else { 1326 // Allow the kerberos service name to be overridden. 1327 service := "postgres" 1328 if cfg.KrbSrvname != "" { 1329 service = cfg.KrbSrvname 1330 } 1331 token, err = cli.GetInitToken(cfg.Host, service) 1332 } 1333 if err != nil { 1334 return fmt.Errorf("pq: failed to get Kerberos ticket: %w", err) 1335 } 1336 1337 w := cn.writeBuf(proto.GSSResponse) 1338 w.bytes(token) 1339 err = cn.send(w) 1340 if err != nil { 1341 return err 1342 } 1343 1344 // Store for GSSAPI continue message 1345 cn.gss = cli 1346 return nil 1347 1348 case proto.AuthReqGSSCont: // GSSAPI continue 1349 if cn.gss == nil { 1350 return errors.New("pq: GSSAPI protocol error") 1351 } 1352 1353 done, tokOut, err := cn.gss.Continue([]byte(*r)) 1354 if err == nil && !done { 1355 w := cn.writeBuf(proto.SASLInitialResponse) 1356 w.bytes(tokOut) 1357 err = cn.send(w) 1358 if err != nil { 1359 return err 1360 } 1361 } 1362 1363 // Errors fall through and read the more detailed message from the 1364 // server. 1365 return nil 1366 1367 case proto.AuthReqSASL: 1368 sc := scram.NewClient(sha256.New, cfg.User, cfg.Password) 1369 sc.Step(nil) 1370 if sc.Err() != nil { 1371 return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) 1372 } 1373 scOut := sc.Out() 1374 1375 w := cn.writeBuf(proto.SASLResponse) 1376 w.string("SCRAM-SHA-256") 1377 w.int32(len(scOut)) 1378 w.bytes(scOut) 1379 err := cn.send(w) 1380 if err != nil { 1381 return err 1382 } 1383 1384 t, r, err := cn.recv() 1385 if err != nil { 1386 return err 1387 } 1388 if t != proto.AuthenticationRequest { 1389 return fmt.Errorf("pq: unexpected password response: %q", t) 1390 } 1391 1392 if r.int32() != int(proto.AuthReqSASLCont) { 1393 return fmt.Errorf("pq: unexpected authentication response: %q", t) 1394 } 1395 1396 nextStep := r.next(len(*r)) 1397 sc.Step(nextStep) 1398 if sc.Err() != nil { 1399 return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) 1400 } 1401 1402 scOut = sc.Out() 1403 w = cn.writeBuf(proto.SASLResponse) 1404 w.bytes(scOut) 1405 err = cn.send(w) 1406 if err != nil { 1407 return err 1408 } 1409 1410 t, r, err = cn.recv() 1411 if err != nil { 1412 return err 1413 } 1414 if t != proto.AuthenticationRequest { 1415 return fmt.Errorf("pq: unexpected password response: %q", t) 1416 } 1417 1418 if r.int32() != int(proto.AuthReqSASLFin) { 1419 return fmt.Errorf("pq: unexpected authentication response: %q", t) 1420 } 1421 1422 nextStep = r.next(len(*r)) 1423 sc.Step(nextStep) 1424 if sc.Err() != nil { 1425 return fmt.Errorf("pq: SCRAM-SHA-256 error: %w", sc.Err()) 1426 } 1427 1428 return nil 1429 } 1430 } 1431 1432 // parseComplete parses the "command tag" from a CommandComplete message, and 1433 // returns the number of rows affected (if applicable) and a string identifying 1434 // only the command that was executed, e.g. "ALTER TABLE". Returns an error if 1435 // the command can cannot be parsed. 1436 func (cn *conn) parseComplete(commandTag string) (driver.Result, string, error) { 1437 commandsWithAffectedRows := []string{ 1438 "SELECT ", 1439 // INSERT is handled below 1440 "UPDATE ", 1441 "DELETE ", 1442 "FETCH ", 1443 "MOVE ", 1444 "COPY ", 1445 } 1446 1447 var affectedRows *string 1448 for _, tag := range commandsWithAffectedRows { 1449 if strings.HasPrefix(commandTag, tag) { 1450 t := commandTag[len(tag):] 1451 affectedRows = &t 1452 commandTag = tag[:len(tag)-1] 1453 break 1454 } 1455 } 1456 // INSERT also includes the oid of the inserted row in its command tag. Oids 1457 // in user tables are deprecated, and the oid is only returned when exactly 1458 // one row is inserted, so it's unlikely to be of value to any real-world 1459 // application and we can ignore it. 1460 if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { 1461 parts := strings.Split(commandTag, " ") 1462 if len(parts) != 3 { 1463 cn.err.set(driver.ErrBadConn) 1464 return nil, "", fmt.Errorf("pq: unexpected INSERT command tag %s", commandTag) 1465 } 1466 affectedRows = &parts[len(parts)-1] 1467 commandTag = "INSERT" 1468 } 1469 // There should be no affected rows attached to the tag, just return it 1470 if affectedRows == nil { 1471 return driver.RowsAffected(0), commandTag, nil 1472 } 1473 n, err := strconv.ParseInt(*affectedRows, 10, 64) 1474 if err != nil { 1475 cn.err.set(driver.ErrBadConn) 1476 return nil, "", fmt.Errorf("pq: could not parse commandTag: %w", err) 1477 } 1478 return driver.RowsAffected(n), commandTag, nil 1479 } 1480 1481 func md5s(s string) string { 1482 h := md5.New() 1483 h.Write([]byte(s)) 1484 return fmt.Sprintf("%x", h.Sum(nil)) 1485 } 1486 1487 func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.NamedValue) error { 1488 // Do one pass over the parameters to see if we're going to send any of them 1489 // over in binary. If we are, create a paramFormats array at the same time. 1490 var paramFormats []int 1491 for i, x := range args { 1492 _, ok := x.Value.([]byte) 1493 if ok { 1494 if paramFormats == nil { 1495 paramFormats = make([]int, len(args)) 1496 } 1497 paramFormats[i] = 1 1498 } 1499 } 1500 if paramFormats == nil { 1501 b.int16(0) 1502 } else { 1503 b.int16(len(paramFormats)) 1504 for _, x := range paramFormats { 1505 b.int16(x) 1506 } 1507 } 1508 1509 b.int16(len(args)) 1510 for _, x := range args { 1511 if x.Value == nil { 1512 b.int32(-1) 1513 } else if xx, ok := x.Value.([]byte); ok && xx == nil { 1514 b.int32(-1) 1515 } else { 1516 datum, err := binaryEncode(x.Value) 1517 if err != nil { 1518 return err 1519 } 1520 b.int32(len(datum)) 1521 b.bytes(datum) 1522 } 1523 } 1524 return nil 1525 } 1526 1527 func (cn *conn) sendBinaryModeQuery(query string, args []driver.NamedValue) error { 1528 if len(args) >= 65536 { 1529 return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) 1530 } 1531 1532 b := cn.writeBuf(proto.Parse) 1533 b.byte(0) // unnamed statement 1534 b.string(query) 1535 b.int16(0) 1536 1537 b.next(proto.Bind) 1538 b.int16(0) // unnamed portal and statement 1539 err := cn.sendBinaryParameters(b, args) 1540 if err != nil { 1541 return err 1542 } 1543 b.bytes(colFmtDataAllText) 1544 1545 b.next(proto.Describe) 1546 b.byte(proto.Parse) 1547 b.byte(0) // unnamed portal 1548 1549 b.next(proto.Execute) 1550 b.byte(0) 1551 b.int32(0) 1552 1553 b.next(proto.Sync) 1554 return cn.send(b) 1555 } 1556 1557 func (cn *conn) processParameterStatus(r *readBuf) { 1558 switch r.string() { 1559 default: 1560 // ignore 1561 case "server_version": 1562 var major1, major2 int 1563 _, err := fmt.Sscanf(r.string(), "%d.%d", &major1, &major2) 1564 if err == nil { 1565 cn.parameterStatus.serverVersion = major1*10000 + major2*100 1566 } 1567 case "TimeZone": 1568 switch tz := r.string(); tz { 1569 case "UTC", "Etc/UTC", "Etc/Universal", "Etc/Zulu", "Etc/UCT": 1570 cn.parameterStatus.currentLocation = time.UTC 1571 default: 1572 var err error 1573 cn.parameterStatus.currentLocation, err = time.LoadLocation(tz) 1574 if err != nil { 1575 cn.parameterStatus.currentLocation = nil 1576 } 1577 } 1578 // Use sql.NullBool so we can distinguish between false and not sent. If 1579 // it's not sent we use a query to get the value – I don't know when these 1580 // parameters are not sent, but this is what libpq does. 1581 case "in_hot_standby": 1582 b, err := pqutil.ParseBool(r.string()) 1583 if err == nil { 1584 cn.parameterStatus.inHotStandby = sql.NullBool{Valid: true, Bool: b} 1585 } 1586 case "default_transaction_read_only": 1587 b, err := pqutil.ParseBool(r.string()) 1588 if err == nil { 1589 cn.parameterStatus.defaultTransactionReadOnly = sql.NullBool{Valid: true, Bool: b} 1590 } 1591 } 1592 } 1593 1594 func (cn *conn) processReadyForQuery(r *readBuf) { 1595 cn.txnStatus = transactionStatus(r.byte()) 1596 } 1597 1598 func (cn *conn) readReadyForQuery() error { 1599 t, r, err := cn.recv1() 1600 if err != nil { 1601 return err 1602 } 1603 switch t { 1604 case proto.ReadyForQuery: 1605 cn.processReadyForQuery(r) 1606 return nil 1607 case proto.ErrorResponse: 1608 err := parseError(r, "") 1609 cn.err.set(driver.ErrBadConn) 1610 return err 1611 default: 1612 cn.err.set(driver.ErrBadConn) 1613 return fmt.Errorf("pq: unexpected message %q; expected ReadyForQuery", t) 1614 } 1615 } 1616 1617 func (cn *conn) readParseResponse() error { 1618 t, r, err := cn.recv1() 1619 if err != nil { 1620 return err 1621 } 1622 switch t { 1623 case proto.ParseComplete: 1624 return nil 1625 case proto.ErrorResponse: 1626 err := parseError(r, "") 1627 _ = cn.readReadyForQuery() 1628 return err 1629 default: 1630 cn.err.set(driver.ErrBadConn) 1631 return fmt.Errorf("pq: unexpected Parse response %q", t) 1632 } 1633 } 1634 1635 func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc, _ error) { 1636 for { 1637 t, r, err := cn.recv1() 1638 if err != nil { 1639 return nil, nil, nil, err 1640 } 1641 switch t { 1642 case proto.ParameterDescription: 1643 nparams := r.int16() 1644 paramTyps = make([]oid.Oid, nparams) 1645 for i := range paramTyps { 1646 paramTyps[i] = r.oid() 1647 } 1648 case proto.NoData: 1649 return paramTyps, nil, nil, nil 1650 case proto.RowDescription: 1651 colNames, colTyps = parseStatementRowDescribe(r) 1652 return paramTyps, colNames, colTyps, nil 1653 case proto.ErrorResponse: 1654 err := parseError(r, "") 1655 _ = cn.readReadyForQuery() 1656 return nil, nil, nil, err 1657 default: 1658 cn.err.set(driver.ErrBadConn) 1659 return nil, nil, nil, fmt.Errorf("pq: unexpected Describe statement response %q", t) 1660 } 1661 } 1662 } 1663 1664 func (cn *conn) readPortalDescribeResponse() (rowsHeader, error) { 1665 t, r, err := cn.recv1() 1666 if err != nil { 1667 return rowsHeader{}, err 1668 } 1669 switch t { 1670 case proto.RowDescription: 1671 return parsePortalRowDescribe(r), nil 1672 case proto.NoData: 1673 return rowsHeader{}, nil 1674 case proto.ErrorResponse: 1675 err := parseError(r, "") 1676 _ = cn.readReadyForQuery() 1677 return rowsHeader{}, err 1678 default: 1679 cn.err.set(driver.ErrBadConn) 1680 return rowsHeader{}, fmt.Errorf("pq: unexpected Describe response %q", t) 1681 } 1682 } 1683 1684 func (cn *conn) readBindResponse() error { 1685 t, r, err := cn.recv1() 1686 if err != nil { 1687 return err 1688 } 1689 switch t { 1690 case proto.BindComplete: 1691 return nil 1692 case proto.ErrorResponse: 1693 err := parseError(r, "") 1694 _ = cn.readReadyForQuery() 1695 return err 1696 default: 1697 cn.err.set(driver.ErrBadConn) 1698 return fmt.Errorf("pq: unexpected Bind response %q", t) 1699 } 1700 } 1701 1702 func (cn *conn) postExecuteWorkaround() error { 1703 // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores 1704 // any errors from rows.Next, which masks errors that happened during the 1705 // execution of the query. To avoid the problem in common cases, we wait 1706 // here for one more message from the database. If it's not an error the 1707 // query will likely succeed (or perhaps has already, if it's a 1708 // CommandComplete), so we push the message into the conn struct; recv1 1709 // will return it as the next message for rows.Next or rows.Close. 1710 // However, if it's an error, we wait until ReadyForQuery and then return 1711 // the error to our caller. 1712 for { 1713 t, r, err := cn.recv1() 1714 if err != nil { 1715 return err 1716 } 1717 switch t { 1718 case proto.ErrorResponse: 1719 err := parseError(r, "") 1720 _ = cn.readReadyForQuery() 1721 return err 1722 case proto.CommandComplete, proto.DataRow, proto.EmptyQueryResponse: 1723 // the query didn't fail, but we can't process this message 1724 return cn.saveMessage(t, r) 1725 default: 1726 cn.err.set(driver.ErrBadConn) 1727 return fmt.Errorf("pq: unexpected message during extended query execution: %q", t) 1728 } 1729 } 1730 } 1731 1732 // Only for Exec(), since we ignore the returned data 1733 func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, resErr error) { 1734 for { 1735 t, r, err := cn.recv1() 1736 if err != nil { 1737 return nil, "", err 1738 } 1739 switch t { 1740 case proto.CommandComplete: 1741 if resErr != nil { 1742 cn.err.set(driver.ErrBadConn) 1743 return nil, "", fmt.Errorf("pq: unexpected CommandComplete after error %s", resErr) 1744 } 1745 res, commandTag, err = cn.parseComplete(r.string()) 1746 if err != nil { 1747 return nil, "", err 1748 } 1749 case proto.ReadyForQuery: 1750 cn.processReadyForQuery(r) 1751 if res == nil && resErr == nil { 1752 resErr = errUnexpectedReady 1753 } 1754 return res, commandTag, resErr 1755 case proto.ErrorResponse: 1756 resErr = parseError(r, "") 1757 case proto.RowDescription, proto.DataRow, proto.EmptyQueryResponse: 1758 if resErr != nil { 1759 cn.err.set(driver.ErrBadConn) 1760 return nil, "", fmt.Errorf("pq: unexpected %q after error %s", t, resErr) 1761 } 1762 if t == proto.EmptyQueryResponse { 1763 res = emptyRows 1764 } 1765 // ignore any results 1766 default: 1767 cn.err.set(driver.ErrBadConn) 1768 return nil, "", fmt.Errorf("pq: unknown %s response: %q", protocolState, t) 1769 } 1770 } 1771 } 1772 1773 func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { 1774 n := r.int16() 1775 colNames = make([]string, n) 1776 colTyps = make([]fieldDesc, n) 1777 for i := range colNames { 1778 colNames[i] = r.string() 1779 r.next(6) 1780 colTyps[i].OID = r.oid() 1781 colTyps[i].Len = r.int16() 1782 colTyps[i].Mod = r.int32() 1783 // format code not known when describing a statement; always 0 1784 r.next(2) 1785 } 1786 return 1787 } 1788 1789 func parsePortalRowDescribe(r *readBuf) rowsHeader { 1790 n := r.int16() 1791 colNames := make([]string, n) 1792 colFmts := make([]format, n) 1793 colTyps := make([]fieldDesc, n) 1794 for i := range colNames { 1795 colNames[i] = r.string() 1796 r.next(6) 1797 colTyps[i].OID = r.oid() 1798 colTyps[i].Len = r.int16() 1799 colTyps[i].Mod = r.int32() 1800 colFmts[i] = format(r.int16()) 1801 } 1802 return rowsHeader{ 1803 colNames: colNames, 1804 colFmts: colFmts, 1805 colTyps: colTyps, 1806 } 1807 } 1808 1809 func (cn *conn) ResetSession(ctx context.Context) error { 1810 // Ensure bad connections are reported: From database/sql/driver: 1811 // If a connection is never returned to the connection pool but immediately reused, then 1812 // ResetSession is called prior to reuse but IsValid is not called. 1813 return cn.err.get() 1814 } 1815 1816 func (cn *conn) IsValid() bool { 1817 return cn.err.get() == nil 1818 }