conn_go18.go (5475B)
1 package pq 2 3 import ( 4 "context" 5 "database/sql" 6 "database/sql/driver" 7 "fmt" 8 "io" 9 "time" 10 11 "github.com/lib/pq/internal/proto" 12 ) 13 14 const watchCancelDialContextTimeout = 10 * time.Second 15 16 // Implement the "QueryerContext" interface 17 func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 18 finish := cn.watchCancel(ctx) 19 r, err := cn.query(query, args) 20 if err != nil { 21 if finish != nil { 22 finish() 23 } 24 return nil, err 25 } 26 r.finish = finish 27 return r, nil 28 } 29 30 // Implement the "ExecerContext" interface 31 func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 32 list := make([]driver.Value, len(args)) 33 for i, nv := range args { 34 list[i] = nv.Value 35 } 36 37 if finish := cn.watchCancel(ctx); finish != nil { 38 defer finish() 39 } 40 41 return cn.Exec(query, list) 42 } 43 44 // Implement the "ConnPrepareContext" interface 45 func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { 46 if finish := cn.watchCancel(ctx); finish != nil { 47 defer finish() 48 } 49 return cn.Prepare(query) 50 } 51 52 // Implement the "ConnBeginTx" interface 53 func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { 54 var mode string 55 switch sql.IsolationLevel(opts.Isolation) { 56 case sql.LevelDefault: 57 // Don't touch mode: use the server's default 58 case sql.LevelReadUncommitted: 59 mode = " ISOLATION LEVEL READ UNCOMMITTED" 60 case sql.LevelReadCommitted: 61 mode = " ISOLATION LEVEL READ COMMITTED" 62 case sql.LevelRepeatableRead: 63 mode = " ISOLATION LEVEL REPEATABLE READ" 64 case sql.LevelSerializable: 65 mode = " ISOLATION LEVEL SERIALIZABLE" 66 default: 67 return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) 68 } 69 if opts.ReadOnly { 70 mode += " READ ONLY" 71 } else { 72 mode += " READ WRITE" 73 } 74 75 tx, err := cn.begin(mode) 76 if err != nil { 77 return nil, err 78 } 79 cn.txnFinish = cn.watchCancel(ctx) 80 return tx, nil 81 } 82 83 func (cn *conn) Ping(ctx context.Context) error { 84 if finish := cn.watchCancel(ctx); finish != nil { 85 defer finish() 86 } 87 rows, err := cn.simpleQuery(";") 88 if err != nil { 89 return driver.ErrBadConn 90 } 91 _ = rows.Close() 92 return nil 93 } 94 95 func (cn *conn) watchCancel(ctx context.Context) func() { 96 if done := ctx.Done(); done != nil { 97 finished := make(chan struct{}, 1) 98 go func() { 99 select { 100 case <-done: 101 select { 102 case finished <- struct{}{}: 103 default: 104 // We raced with the finish func, let the next query handle this with the 105 // context. 106 return 107 } 108 109 // Set the connection state to bad so it does not get reused. 110 cn.err.set(ctx.Err()) 111 112 // At this point the function level context is canceled, 113 // so it must not be used for the additional network 114 // request to cancel the query. 115 // Create a new context to pass into the dial. 116 ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) 117 defer cancel() 118 119 _ = cn.cancel(ctxCancel) 120 case <-finished: 121 } 122 }() 123 return func() { 124 select { 125 case <-finished: 126 cn.err.set(ctx.Err()) 127 _ = cn.Close() 128 case finished <- struct{}{}: 129 } 130 } 131 } 132 return nil 133 } 134 135 func (cn *conn) cancel(ctx context.Context) error { 136 // Use a copy since a new connection is created here. This is necessary 137 // because cancel is called from a goroutine in watchCancel. 138 cfg := cn.cfg.Clone() 139 140 c, err := dial(ctx, cn.dialer, cfg) 141 if err != nil { 142 return err 143 } 144 defer func() { _ = c.Close() }() 145 146 cn2 := conn{c: c} 147 err = cn2.ssl(cfg, cfg.SSLMode) 148 if err != nil { 149 return err 150 } 151 152 w := cn2.writeBuf(0) 153 w.int32(proto.CancelRequestCode) 154 w.int32(cn.pid) 155 w.bytes(cn.secretKey) 156 if err := cn2.sendStartupPacket(w); err != nil { 157 return err 158 } 159 160 // Read until EOF to ensure that the server received the cancel. 161 _, err = io.Copy(io.Discard, c) 162 return err 163 } 164 165 // Implement the "StmtQueryContext" interface 166 func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { 167 finish := st.watchCancel(ctx) 168 r, err := st.query(args) 169 if err != nil { 170 if finish != nil { 171 finish() 172 } 173 return nil, err 174 } 175 r.finish = finish 176 return r, nil 177 } 178 179 // Implement the "StmtExecContext" interface 180 func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { 181 if finish := st.watchCancel(ctx); finish != nil { 182 defer finish() 183 } 184 if err := st.cn.err.get(); err != nil { 185 return nil, err 186 } 187 188 err := st.exec(args) 189 if err != nil { 190 return nil, st.cn.handleError(err) 191 } 192 res, _, err := st.cn.readExecuteResponse("simple query") 193 return res, st.cn.handleError(err) 194 } 195 196 // watchCancel is implemented on stmt in order to not mark the parent conn as bad 197 func (st *stmt) watchCancel(ctx context.Context) func() { 198 if done := ctx.Done(); done != nil { 199 finished := make(chan struct{}) 200 go func() { 201 select { 202 case <-done: 203 // At this point the function level context is canceled, so it 204 // must not be used for the additional network request to cancel 205 // the query. Create a new context to pass into the dial. 206 ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout) 207 defer cancel() 208 209 _ = st.cancel(ctxCancel) 210 finished <- struct{}{} 211 case <-finished: 212 } 213 }() 214 return func() { 215 select { 216 case <-finished: 217 case finished <- struct{}{}: 218 } 219 } 220 } 221 return nil 222 } 223 224 func (st *stmt) cancel(ctx context.Context) error { 225 return st.cn.cancel(ctx) 226 }