stmt.go (2882B)
1 package pq 2 3 import ( 4 "context" 5 "database/sql/driver" 6 "fmt" 7 "os" 8 9 "github.com/lib/pq/internal/proto" 10 "github.com/lib/pq/oid" 11 ) 12 13 type stmt struct { 14 cn *conn 15 name string 16 rowsHeader 17 colFmtData []byte 18 paramTyps []oid.Oid 19 closed bool 20 } 21 22 func (st *stmt) Close() error { 23 if st.closed { 24 return nil 25 } 26 if err := st.cn.err.get(); err != nil { 27 return err 28 } 29 30 w := st.cn.writeBuf(proto.Close) 31 w.byte(proto.Sync) 32 w.string(st.name) 33 err := st.cn.send(w) 34 if err != nil { 35 return st.cn.handleError(err) 36 } 37 err = st.cn.send(st.cn.writeBuf(proto.Sync)) 38 if err != nil { 39 return st.cn.handleError(err) 40 } 41 42 t, _, err := st.cn.recv1() 43 if err != nil { 44 return st.cn.handleError(err) 45 } 46 if t != proto.CloseComplete { 47 st.cn.err.set(driver.ErrBadConn) 48 return fmt.Errorf("pq: unexpected close response: %q", t) 49 } 50 st.closed = true 51 52 t, r, err := st.cn.recv1() 53 if err != nil { 54 return st.cn.handleError(err) 55 } 56 if t != proto.ReadyForQuery { 57 st.cn.err.set(driver.ErrBadConn) 58 return fmt.Errorf("pq: expected ready for query, but got: %q", t) 59 } 60 st.cn.processReadyForQuery(r) 61 62 return nil 63 } 64 65 func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { 66 return st.query(toNamedValue(v)) 67 } 68 69 func (st *stmt) query(v []driver.NamedValue) (*rows, error) { 70 if err := st.cn.err.get(); err != nil { 71 return nil, err 72 } 73 74 err := st.exec(v) 75 if err != nil { 76 return nil, st.cn.handleError(err) 77 } 78 return &rows{ 79 cn: st.cn, 80 rowsHeader: st.rowsHeader, 81 }, nil 82 } 83 84 func (st *stmt) Exec(v []driver.Value) (driver.Result, error) { 85 return st.ExecContext(context.Background(), toNamedValue(v)) 86 } 87 88 func (st *stmt) exec(v []driver.NamedValue) error { 89 if debugProto { 90 fmt.Fprintf(os.Stderr, " START stmt.exec\n") 91 defer fmt.Fprintf(os.Stderr, " END stmt.exec\n") 92 } 93 if len(v) >= 65536 { 94 return fmt.Errorf("pq: got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) 95 } 96 if len(v) != len(st.paramTyps) { 97 return fmt.Errorf("pq: got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) 98 } 99 100 cn := st.cn 101 w := cn.writeBuf(proto.Bind) 102 w.byte(0) // unnamed portal 103 w.string(st.name) 104 105 if cn.cfg.BinaryParameters { 106 err := cn.sendBinaryParameters(w, v) 107 if err != nil { 108 return err 109 } 110 } else { 111 w.int16(0) 112 w.int16(len(v)) 113 for i, x := range v { 114 if x.Value == nil { 115 w.int32(-1) 116 } else { 117 b, err := encode(x.Value, st.paramTyps[i]) 118 if err != nil { 119 return err 120 } 121 if b == nil { 122 w.int32(-1) 123 } else { 124 w.int32(len(b)) 125 w.bytes(b) 126 } 127 } 128 } 129 } 130 w.bytes(st.colFmtData) 131 132 w.next(proto.Execute) 133 w.byte(0) 134 w.int32(0) 135 136 w.next(proto.Sync) 137 err := cn.send(w) 138 if err != nil { 139 return err 140 } 141 err = cn.readBindResponse() 142 if err != nil { 143 return err 144 } 145 return cn.postExecuteWorkaround() 146 } 147 148 func (st *stmt) NumInput() int { 149 return len(st.paramTyps) 150 }