taldir

Directory service to resolve wallet mailboxes by messenger addresses
Log | Files | Refs | Submodules | README | LICENSE

copy.go (7531B)


      1 package pq
      2 
      3 import (
      4 	"context"
      5 	"database/sql/driver"
      6 	"encoding/binary"
      7 	"errors"
      8 	"fmt"
      9 	"os"
     10 	"sync"
     11 
     12 	"github.com/lib/pq/internal/proto"
     13 )
     14 
     15 var (
     16 	errCopyInClosed               = errors.New("pq: copyin statement has already been closed")
     17 	errBinaryCopyNotSupported     = errors.New("pq: only text format supported for COPY")
     18 	errCopyToNotSupported         = errors.New("pq: COPY TO is not supported")
     19 	errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
     20 )
     21 
     22 type copyin struct {
     23 	cn      *conn
     24 	buffer  []byte
     25 	rowData chan []byte
     26 	done    chan bool
     27 	closed  bool
     28 	mu      struct {
     29 		sync.Mutex
     30 		err error
     31 		driver.Result
     32 	}
     33 }
     34 
     35 const (
     36 	ciBufferSize = 64 * 1024
     37 	// flush buffer before the buffer is filled up and needs reallocation
     38 	ciBufferFlushSize = 63 * 1024
     39 )
     40 
     41 func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, resErr error) {
     42 	if !cn.isInTransaction() {
     43 		return nil, errCopyNotSupportedOutsideTxn
     44 	}
     45 
     46 	ci := &copyin{
     47 		cn:      cn,
     48 		buffer:  make([]byte, 0, ciBufferSize),
     49 		rowData: make(chan []byte),
     50 		done:    make(chan bool, 1),
     51 	}
     52 	// add CopyData identifier + 4 bytes for message length
     53 	ci.buffer = append(ci.buffer, byte(proto.CopyDataRequest), 0, 0, 0, 0)
     54 
     55 	b := cn.writeBuf(proto.Query)
     56 	b.string(q)
     57 	err := cn.send(b)
     58 	if err != nil {
     59 		return nil, err
     60 	}
     61 
     62 awaitCopyInResponse:
     63 	for {
     64 		t, r, err := cn.recv1()
     65 		if err != nil {
     66 			return nil, err
     67 		}
     68 		switch t {
     69 		case proto.CopyInResponse:
     70 			if r.byte() != 0 {
     71 				resErr = errBinaryCopyNotSupported
     72 				break awaitCopyInResponse
     73 			}
     74 			go ci.resploop()
     75 			return ci, nil
     76 		case proto.CopyOutResponse:
     77 			resErr = errCopyToNotSupported
     78 			break awaitCopyInResponse
     79 		case proto.ErrorResponse:
     80 			resErr = parseError(r, q)
     81 		case proto.ReadyForQuery:
     82 			if resErr == nil {
     83 				ci.setBad(driver.ErrBadConn)
     84 				return nil, fmt.Errorf("pq: unexpected ReadyForQuery in response to COPY")
     85 			}
     86 			cn.processReadyForQuery(r)
     87 			return nil, resErr
     88 		default:
     89 			ci.setBad(driver.ErrBadConn)
     90 			return nil, fmt.Errorf("pq: unknown response for copy query: %q", t)
     91 		}
     92 	}
     93 
     94 	// something went wrong, abort COPY before we return
     95 	b = cn.writeBuf(proto.CopyFail)
     96 	b.string(resErr.Error())
     97 	err = cn.send(b)
     98 	if err != nil {
     99 		return nil, err
    100 	}
    101 
    102 	for {
    103 		t, r, err := cn.recv1()
    104 		if err != nil {
    105 			return nil, err
    106 		}
    107 
    108 		switch t {
    109 		case proto.CopyDoneResponse, proto.CommandComplete, proto.ErrorResponse:
    110 		case proto.ReadyForQuery:
    111 			// correctly aborted, we're done
    112 			cn.processReadyForQuery(r)
    113 			return nil, resErr
    114 		default:
    115 			ci.setBad(driver.ErrBadConn)
    116 			return nil, fmt.Errorf("pq: unknown response for CopyFail: %q", t)
    117 		}
    118 	}
    119 }
    120 
    121 func (ci *copyin) flush(buf []byte) error {
    122 	if len(buf)-1 > proto.MaxUint32 {
    123 		return errors.New("pq: too many columns")
    124 	}
    125 	if debugProto {
    126 		fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d  %q\n", proto.RequestCode(buf[0]), len(buf)-5, buf[5:])
    127 	}
    128 	binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) // Set message length (without message identifier).
    129 	_, err := ci.cn.c.Write(buf)
    130 	return err
    131 }
    132 
    133 func (ci *copyin) resploop() {
    134 	for {
    135 		var r readBuf
    136 		t, err := ci.cn.recvMessage(&r)
    137 		if err != nil {
    138 			ci.setBad(driver.ErrBadConn)
    139 			ci.setError(err)
    140 			ci.done <- true
    141 			return
    142 		}
    143 		switch t {
    144 		case proto.CommandComplete:
    145 			// complete
    146 			res, _, err := ci.cn.parseComplete(r.string())
    147 			if err != nil {
    148 				panic(err)
    149 			}
    150 			ci.setResult(res)
    151 		case proto.NoticeResponse:
    152 			if n := ci.cn.noticeHandler; n != nil {
    153 				n(parseError(&r, ""))
    154 			}
    155 		case proto.ReadyForQuery:
    156 			ci.cn.processReadyForQuery(&r)
    157 			ci.done <- true
    158 			return
    159 		case proto.ErrorResponse:
    160 			err := parseError(&r, "")
    161 			ci.setError(err)
    162 		default:
    163 			ci.setBad(driver.ErrBadConn)
    164 			ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
    165 			ci.done <- true
    166 			return
    167 		}
    168 	}
    169 }
    170 
    171 func (ci *copyin) setBad(err error) {
    172 	ci.cn.err.set(err)
    173 }
    174 
    175 func (ci *copyin) getBad() error {
    176 	return ci.cn.err.get()
    177 }
    178 
    179 func (ci *copyin) err() error {
    180 	ci.mu.Lock()
    181 	err := ci.mu.err
    182 	ci.mu.Unlock()
    183 	return err
    184 }
    185 
    186 // setError() sets ci.err if one has not been set already.  Caller must not be
    187 // holding ci.Mutex.
    188 func (ci *copyin) setError(err error) {
    189 	ci.mu.Lock()
    190 	if ci.mu.err == nil {
    191 		ci.mu.err = err
    192 	}
    193 	ci.mu.Unlock()
    194 }
    195 
    196 func (ci *copyin) setResult(result driver.Result) {
    197 	ci.mu.Lock()
    198 	ci.mu.Result = result
    199 	ci.mu.Unlock()
    200 }
    201 
    202 func (ci *copyin) getResult() driver.Result {
    203 	ci.mu.Lock()
    204 	result := ci.mu.Result
    205 	ci.mu.Unlock()
    206 	if result == nil {
    207 		return driver.RowsAffected(0)
    208 	}
    209 	return result
    210 }
    211 
    212 func (ci *copyin) NumInput() int {
    213 	return -1
    214 }
    215 
    216 func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
    217 	return nil, ErrNotSupported
    218 }
    219 
    220 // Exec inserts values into the COPY stream. The insert is asynchronous
    221 // and Exec can return errors from previous Exec calls to the same
    222 // COPY stmt.
    223 //
    224 // You need to call Exec(nil) to sync the COPY stream and to get any
    225 // errors from pending data, since Stmt.Close() doesn't return errors
    226 // to the user.
    227 func (ci *copyin) Exec(v []driver.Value) (driver.Result, error) {
    228 	if ci.closed {
    229 		return nil, errCopyInClosed
    230 	}
    231 	if err := ci.getBad(); err != nil {
    232 		return nil, err
    233 	}
    234 	if err := ci.err(); err != nil {
    235 		return nil, err
    236 	}
    237 
    238 	if len(v) == 0 {
    239 		if err := ci.Close(); err != nil {
    240 			return driver.RowsAffected(0), err
    241 		}
    242 		return ci.getResult(), nil
    243 	}
    244 
    245 	var (
    246 		numValues = len(v)
    247 		err       error
    248 	)
    249 	for i, value := range v {
    250 		ci.buffer, err = appendEncodedText(ci.buffer, value)
    251 		if err != nil {
    252 			return nil, ci.cn.handleError(err)
    253 		}
    254 		if i < numValues-1 {
    255 			ci.buffer = append(ci.buffer, '\t')
    256 		}
    257 	}
    258 
    259 	ci.buffer = append(ci.buffer, '\n')
    260 
    261 	if len(ci.buffer) > ciBufferFlushSize {
    262 		err := ci.flush(ci.buffer)
    263 		if err != nil {
    264 			return nil, ci.cn.handleError(err)
    265 		}
    266 		// reset buffer, keep bytes for message identifier and length
    267 		ci.buffer = ci.buffer[:5]
    268 	}
    269 
    270 	return driver.RowsAffected(0), nil
    271 }
    272 
    273 // CopyData inserts a raw string into the COPY stream. The insert is
    274 // asynchronous and CopyData can return errors from previous CopyData calls to
    275 // the same COPY stmt.
    276 //
    277 // You need to call Exec(nil) to sync the COPY stream and to get any
    278 // errors from pending data, since Stmt.Close() doesn't return errors
    279 // to the user.
    280 func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, error) {
    281 	if ci.closed {
    282 		return nil, errCopyInClosed
    283 	}
    284 	if finish := ci.cn.watchCancel(ctx); finish != nil {
    285 		defer finish()
    286 	}
    287 	if err := ci.getBad(); err != nil {
    288 		return nil, err
    289 	}
    290 	if err := ci.err(); err != nil {
    291 		return nil, err
    292 	}
    293 
    294 	ci.buffer = append(ci.buffer, []byte(line)...)
    295 	ci.buffer = append(ci.buffer, '\n')
    296 
    297 	if len(ci.buffer) > ciBufferFlushSize {
    298 		err := ci.flush(ci.buffer)
    299 		if err != nil {
    300 			return nil, ci.cn.handleError(err)
    301 		}
    302 
    303 		// reset buffer, keep bytes for message identifier and length
    304 		ci.buffer = ci.buffer[:5]
    305 	}
    306 
    307 	return driver.RowsAffected(0), nil
    308 }
    309 
    310 func (ci *copyin) Close() error {
    311 	if ci.closed { // Don't do anything, we're already closed
    312 		return nil
    313 	}
    314 	ci.closed = true
    315 
    316 	if err := ci.getBad(); err != nil {
    317 		return err
    318 	}
    319 
    320 	if len(ci.buffer) > 0 {
    321 		err := ci.flush(ci.buffer)
    322 		if err != nil {
    323 			return ci.cn.handleError(err)
    324 		}
    325 	}
    326 	// Avoid touching the scratch buffer as resploop could be using it.
    327 	err := ci.cn.sendSimpleMessage(proto.CopyDoneRequest)
    328 	if err != nil {
    329 		return ci.cn.handleError(err)
    330 	}
    331 
    332 	<-ci.done
    333 	ci.cn.inProgress.Store(false)
    334 
    335 	if err := ci.err(); err != nil {
    336 		return err
    337 	}
    338 	return nil
    339 }