taldir

Directory service to resolve wallet mailboxes by messenger addresses
Log | Files | Refs | Submodules | README | LICENSE

conn_go18.go (5475B)


      1 package pq
      2 
      3 import (
      4 	"context"
      5 	"database/sql"
      6 	"database/sql/driver"
      7 	"fmt"
      8 	"io"
      9 	"time"
     10 
     11 	"github.com/lib/pq/internal/proto"
     12 )
     13 
// watchCancelDialContextTimeout is the timeout of the fresh context (derived
// from context.Background) that the watchCancel goroutines pass to cancel once
// the caller's context is already done. cancel only hands that context to dial,
// so this bounds how long opening the cancel-request connection may take.
const watchCancelDialContextTimeout = 10 * time.Second
     15 
     16 // Implement the "QueryerContext" interface
     17 func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
     18 	finish := cn.watchCancel(ctx)
     19 	r, err := cn.query(query, args)
     20 	if err != nil {
     21 		if finish != nil {
     22 			finish()
     23 		}
     24 		return nil, err
     25 	}
     26 	r.finish = finish
     27 	return r, nil
     28 }
     29 
     30 // Implement the "ExecerContext" interface
     31 func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
     32 	list := make([]driver.Value, len(args))
     33 	for i, nv := range args {
     34 		list[i] = nv.Value
     35 	}
     36 
     37 	if finish := cn.watchCancel(ctx); finish != nil {
     38 		defer finish()
     39 	}
     40 
     41 	return cn.Exec(query, list)
     42 }
     43 
     44 // Implement the "ConnPrepareContext" interface
     45 func (cn *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
     46 	if finish := cn.watchCancel(ctx); finish != nil {
     47 		defer finish()
     48 	}
     49 	return cn.Prepare(query)
     50 }
     51 
     52 // Implement the "ConnBeginTx" interface
     53 func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
     54 	var mode string
     55 	switch sql.IsolationLevel(opts.Isolation) {
     56 	case sql.LevelDefault:
     57 		// Don't touch mode: use the server's default
     58 	case sql.LevelReadUncommitted:
     59 		mode = " ISOLATION LEVEL READ UNCOMMITTED"
     60 	case sql.LevelReadCommitted:
     61 		mode = " ISOLATION LEVEL READ COMMITTED"
     62 	case sql.LevelRepeatableRead:
     63 		mode = " ISOLATION LEVEL REPEATABLE READ"
     64 	case sql.LevelSerializable:
     65 		mode = " ISOLATION LEVEL SERIALIZABLE"
     66 	default:
     67 		return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
     68 	}
     69 	if opts.ReadOnly {
     70 		mode += " READ ONLY"
     71 	} else {
     72 		mode += " READ WRITE"
     73 	}
     74 
     75 	tx, err := cn.begin(mode)
     76 	if err != nil {
     77 		return nil, err
     78 	}
     79 	cn.txnFinish = cn.watchCancel(ctx)
     80 	return tx, nil
     81 }
     82 
     83 func (cn *conn) Ping(ctx context.Context) error {
     84 	if finish := cn.watchCancel(ctx); finish != nil {
     85 		defer finish()
     86 	}
     87 	rows, err := cn.simpleQuery(";")
     88 	if err != nil {
     89 		return driver.ErrBadConn
     90 	}
     91 	_ = rows.Close()
     92 	return nil
     93 }
     94 
// watchCancel watches ctx while cn runs a command and returns a finish func the
// caller invokes once the command has completed. It returns nil when
// ctx.Done() is nil (the context can never be canceled), so callers must
// nil-check the result.
//
// The spawned goroutine and the finish func race to put a value into the
// one-slot buffered channel finished:
//   - If ctx is done first, the goroutine fills the buffer, marks cn as bad via
//     cn.err.set and sends a cancel request on a separate connection. A later
//     finish call then receives the buffered value, records ctx.Err() again and
//     closes cn. It does not wait for the cancel request to complete.
//   - If finish runs first, its value is either received by the goroutine,
//     which exits without canceling, or left in the buffer. In the latter case
//     a later ctx cancellation finds the buffer full and the goroutine returns
//     without touching cn.
func (cn *conn) watchCancel(ctx context.Context) func() {
	if done := ctx.Done(); done != nil {
		// Buffered so that neither side ever blocks on the handshake.
		finished := make(chan struct{}, 1)
		go func() {
			select {
			case <-done:
				select {
				case finished <- struct{}{}:
				default:
					// We raced with the finish func, let the next query handle this with the
					// context.
					return
				}

				// Set the connection state to bad so it does not get reused.
				cn.err.set(ctx.Err())

				// At this point the function level context is canceled,
				// so it must not be used for the additional network
				// request to cancel the query.
				// Create a new context to pass into the dial.
				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
				defer cancel()

				// Best-effort: the connection is already marked bad either way.
				_ = cn.cancel(ctxCancel)
			case <-finished:
				// finish ran first; nothing to cancel.
			}
		}()
		return func() {
			select {
			case <-finished:
				// The goroutine won the race and canceled the command, so the
				// connection state is unreliable: record the error and close it.
				cn.err.set(ctx.Err())
				_ = cn.Close()
			case finished <- struct{}{}:
				// We won the race; tell the goroutine to stand down.
			}
		}
	}
	return nil
}
    134 
    135 func (cn *conn) cancel(ctx context.Context) error {
    136 	// Use a copy since a new connection is created here. This is necessary
    137 	// because cancel is called from a goroutine in watchCancel.
    138 	cfg := cn.cfg.Clone()
    139 
    140 	c, err := dial(ctx, cn.dialer, cfg)
    141 	if err != nil {
    142 		return err
    143 	}
    144 	defer func() { _ = c.Close() }()
    145 
    146 	cn2 := conn{c: c}
    147 	err = cn2.ssl(cfg, cfg.SSLMode)
    148 	if err != nil {
    149 		return err
    150 	}
    151 
    152 	w := cn2.writeBuf(0)
    153 	w.int32(proto.CancelRequestCode)
    154 	w.int32(cn.pid)
    155 	w.bytes(cn.secretKey)
    156 	if err := cn2.sendStartupPacket(w); err != nil {
    157 		return err
    158 	}
    159 
    160 	// Read until EOF to ensure that the server received the cancel.
    161 	_, err = io.Copy(io.Discard, c)
    162 	return err
    163 }
    164 
    165 // Implement the "StmtQueryContext" interface
    166 func (st *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
    167 	finish := st.watchCancel(ctx)
    168 	r, err := st.query(args)
    169 	if err != nil {
    170 		if finish != nil {
    171 			finish()
    172 		}
    173 		return nil, err
    174 	}
    175 	r.finish = finish
    176 	return r, nil
    177 }
    178 
    179 // Implement the "StmtExecContext" interface
    180 func (st *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
    181 	if finish := st.watchCancel(ctx); finish != nil {
    182 		defer finish()
    183 	}
    184 	if err := st.cn.err.get(); err != nil {
    185 		return nil, err
    186 	}
    187 
    188 	err := st.exec(args)
    189 	if err != nil {
    190 		return nil, st.cn.handleError(err)
    191 	}
    192 	res, _, err := st.cn.readExecuteResponse("simple query")
    193 	return res, st.cn.handleError(err)
    194 }
    195 
// watchCancel is implemented on stmt in order to not mark the parent conn as bad.
//
// Unlike conn.watchCancel, it never calls cn.err.set or closes the connection.
// It returns nil when ctx.Done() is nil; otherwise the caller invokes the
// returned finish func once the statement has completed.
//
// finished is unbuffered, so the goroutine and finish always rendezvous:
//   - If finish runs while ctx is still live, its send is received by the
//     goroutine, which then exits.
//   - If ctx is done first, the goroutine calls st.cancel and only then sends
//     on finished. A finish call made meanwhile blocks until st.cancel returns
//     and that send arrives.
//
// If ctx is canceled and finish is never called, the goroutine blocks forever
// on its send. Calling finish more than once blocks forever because the
// goroutine has already exited.
//
// NOTE(review): if ctx is canceled at the moment finish is called, the
// goroutine's select may pick the done branch and send a cancel request even
// though the statement already completed. Confirm this is acceptable for the
// connection's next command.
func (st *stmt) watchCancel(ctx context.Context) func() {
	if done := ctx.Done(); done != nil {
		finished := make(chan struct{})
		go func() {
			select {
			case <-done:
				// At this point the function level context is canceled, so it
				// must not be used for the additional network request to cancel
				// the query. Create a new context to pass into the dial.
				ctxCancel, cancel := context.WithTimeout(context.Background(), watchCancelDialContextTimeout)
				defer cancel()

				_ = st.cancel(ctxCancel)
				// Hand off to finish; blocks until finish is called.
				finished <- struct{}{}
			case <-finished:
				// finish ran first; nothing to cancel.
			}
		}()
		return func() {
			select {
			case <-finished:
				// The goroutine already sent the cancel request.
			case finished <- struct{}{}:
				// Tell the goroutine to stand down.
			}
		}
	}
	return nil
}
    223 
    224 func (st *stmt) cancel(ctx context.Context) error {
    225 	return st.cn.cancel(ctx)
    226 }