copy.go (7531B)
1 package pq 2 3 import ( 4 "context" 5 "database/sql/driver" 6 "encoding/binary" 7 "errors" 8 "fmt" 9 "os" 10 "sync" 11 12 "github.com/lib/pq/internal/proto" 13 ) 14 15 var ( 16 errCopyInClosed = errors.New("pq: copyin statement has already been closed") 17 errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") 18 errCopyToNotSupported = errors.New("pq: COPY TO is not supported") 19 errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") 20 ) 21 22 type copyin struct { 23 cn *conn 24 buffer []byte 25 rowData chan []byte 26 done chan bool 27 closed bool 28 mu struct { 29 sync.Mutex 30 err error 31 driver.Result 32 } 33 } 34 35 const ( 36 ciBufferSize = 64 * 1024 37 // flush buffer before the buffer is filled up and needs reallocation 38 ciBufferFlushSize = 63 * 1024 39 ) 40 41 func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, resErr error) { 42 if !cn.isInTransaction() { 43 return nil, errCopyNotSupportedOutsideTxn 44 } 45 46 ci := ©in{ 47 cn: cn, 48 buffer: make([]byte, 0, ciBufferSize), 49 rowData: make(chan []byte), 50 done: make(chan bool, 1), 51 } 52 // add CopyData identifier + 4 bytes for message length 53 ci.buffer = append(ci.buffer, byte(proto.CopyDataRequest), 0, 0, 0, 0) 54 55 b := cn.writeBuf(proto.Query) 56 b.string(q) 57 err := cn.send(b) 58 if err != nil { 59 return nil, err 60 } 61 62 awaitCopyInResponse: 63 for { 64 t, r, err := cn.recv1() 65 if err != nil { 66 return nil, err 67 } 68 switch t { 69 case proto.CopyInResponse: 70 if r.byte() != 0 { 71 resErr = errBinaryCopyNotSupported 72 break awaitCopyInResponse 73 } 74 go ci.resploop() 75 return ci, nil 76 case proto.CopyOutResponse: 77 resErr = errCopyToNotSupported 78 break awaitCopyInResponse 79 case proto.ErrorResponse: 80 resErr = parseError(r, q) 81 case proto.ReadyForQuery: 82 if resErr == nil { 83 ci.setBad(driver.ErrBadConn) 84 return nil, fmt.Errorf("pq: unexpected ReadyForQuery in response to COPY") 85 } 86 cn.processReadyForQuery(r) 87 return nil, resErr 88 default: 89 ci.setBad(driver.ErrBadConn) 90 return nil, fmt.Errorf("pq: unknown response for copy query: %q", t) 91 } 92 } 93 94 // something went wrong, abort COPY before we return 95 b = cn.writeBuf(proto.CopyFail) 96 b.string(resErr.Error()) 97 err = cn.send(b) 98 if err != nil { 99 return nil, err 100 } 101 102 for { 103 t, r, err := cn.recv1() 104 if err != nil { 105 return nil, err 106 } 107 108 switch t { 109 case proto.CopyDoneResponse, proto.CommandComplete, proto.ErrorResponse: 110 case proto.ReadyForQuery: 111 // correctly aborted, we're done 112 cn.processReadyForQuery(r) 113 return nil, resErr 114 default: 115 ci.setBad(driver.ErrBadConn) 116 return nil, fmt.Errorf("pq: unknown response for CopyFail: %q", t) 117 } 118 } 119 } 120 121 func (ci *copyin) flush(buf []byte) error { 122 if len(buf)-1 > proto.MaxUint32 { 123 return errors.New("pq: too many columns") 124 } 125 if debugProto { 126 fmt.Fprintf(os.Stderr, "CLIENT → %-20s %5d %q\n", proto.RequestCode(buf[0]), len(buf)-5, buf[5:]) 127 } 128 binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1)) // Set message length (without message identifier). 129 _, err := ci.cn.c.Write(buf) 130 return err 131 } 132 133 func (ci *copyin) resploop() { 134 for { 135 var r readBuf 136 t, err := ci.cn.recvMessage(&r) 137 if err != nil { 138 ci.setBad(driver.ErrBadConn) 139 ci.setError(err) 140 ci.done <- true 141 return 142 } 143 switch t { 144 case proto.CommandComplete: 145 // complete 146 res, _, err := ci.cn.parseComplete(r.string()) 147 if err != nil { 148 panic(err) 149 } 150 ci.setResult(res) 151 case proto.NoticeResponse: 152 if n := ci.cn.noticeHandler; n != nil { 153 n(parseError(&r, "")) 154 } 155 case proto.ReadyForQuery: 156 ci.cn.processReadyForQuery(&r) 157 ci.done <- true 158 return 159 case proto.ErrorResponse: 160 err := parseError(&r, "") 161 ci.setError(err) 162 default: 163 ci.setBad(driver.ErrBadConn) 164 ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) 165 ci.done <- true 166 return 167 } 168 } 169 } 170 171 func (ci *copyin) setBad(err error) { 172 ci.cn.err.set(err) 173 } 174 175 func (ci *copyin) getBad() error { 176 return ci.cn.err.get() 177 } 178 179 func (ci *copyin) err() error { 180 ci.mu.Lock() 181 err := ci.mu.err 182 ci.mu.Unlock() 183 return err 184 } 185 186 // setError() sets ci.err if one has not been set already. Caller must not be 187 // holding ci.Mutex. 188 func (ci *copyin) setError(err error) { 189 ci.mu.Lock() 190 if ci.mu.err == nil { 191 ci.mu.err = err 192 } 193 ci.mu.Unlock() 194 } 195 196 func (ci *copyin) setResult(result driver.Result) { 197 ci.mu.Lock() 198 ci.mu.Result = result 199 ci.mu.Unlock() 200 } 201 202 func (ci *copyin) getResult() driver.Result { 203 ci.mu.Lock() 204 result := ci.mu.Result 205 ci.mu.Unlock() 206 if result == nil { 207 return driver.RowsAffected(0) 208 } 209 return result 210 } 211 212 func (ci *copyin) NumInput() int { 213 return -1 214 } 215 216 func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) { 217 return nil, ErrNotSupported 218 } 219 220 // Exec inserts values into the COPY stream. The insert is asynchronous 221 // and Exec can return errors from previous Exec calls to the same 222 // COPY stmt. 223 // 224 // You need to call Exec(nil) to sync the COPY stream and to get any 225 // errors from pending data, since Stmt.Close() doesn't return errors 226 // to the user. 227 func (ci *copyin) Exec(v []driver.Value) (driver.Result, error) { 228 if ci.closed { 229 return nil, errCopyInClosed 230 } 231 if err := ci.getBad(); err != nil { 232 return nil, err 233 } 234 if err := ci.err(); err != nil { 235 return nil, err 236 } 237 238 if len(v) == 0 { 239 if err := ci.Close(); err != nil { 240 return driver.RowsAffected(0), err 241 } 242 return ci.getResult(), nil 243 } 244 245 var ( 246 numValues = len(v) 247 err error 248 ) 249 for i, value := range v { 250 ci.buffer, err = appendEncodedText(ci.buffer, value) 251 if err != nil { 252 return nil, ci.cn.handleError(err) 253 } 254 if i < numValues-1 { 255 ci.buffer = append(ci.buffer, '\t') 256 } 257 } 258 259 ci.buffer = append(ci.buffer, '\n') 260 261 if len(ci.buffer) > ciBufferFlushSize { 262 err := ci.flush(ci.buffer) 263 if err != nil { 264 return nil, ci.cn.handleError(err) 265 } 266 // reset buffer, keep bytes for message identifier and length 267 ci.buffer = ci.buffer[:5] 268 } 269 270 return driver.RowsAffected(0), nil 271 } 272 273 // CopyData inserts a raw string into the COPY stream. The insert is 274 // asynchronous and CopyData can return errors from previous CopyData calls to 275 // the same COPY stmt. 276 // 277 // You need to call Exec(nil) to sync the COPY stream and to get any 278 // errors from pending data, since Stmt.Close() doesn't return errors 279 // to the user. 280 func (ci *copyin) CopyData(ctx context.Context, line string) (driver.Result, error) { 281 if ci.closed { 282 return nil, errCopyInClosed 283 } 284 if finish := ci.cn.watchCancel(ctx); finish != nil { 285 defer finish() 286 } 287 if err := ci.getBad(); err != nil { 288 return nil, err 289 } 290 if err := ci.err(); err != nil { 291 return nil, err 292 } 293 294 ci.buffer = append(ci.buffer, []byte(line)...) 295 ci.buffer = append(ci.buffer, '\n') 296 297 if len(ci.buffer) > ciBufferFlushSize { 298 err := ci.flush(ci.buffer) 299 if err != nil { 300 return nil, ci.cn.handleError(err) 301 } 302 303 // reset buffer, keep bytes for message identifier and length 304 ci.buffer = ci.buffer[:5] 305 } 306 307 return driver.RowsAffected(0), nil 308 } 309 310 func (ci *copyin) Close() error { 311 if ci.closed { // Don't do anything, we're already closed 312 return nil 313 } 314 ci.closed = true 315 316 if err := ci.getBad(); err != nil { 317 return err 318 } 319 320 if len(ci.buffer) > 0 { 321 err := ci.flush(ci.buffer) 322 if err != nil { 323 return ci.cn.handleError(err) 324 } 325 } 326 // Avoid touching the scratch buffer as resploop could be using it. 327 err := ci.cn.sendSimpleMessage(proto.CopyDoneRequest) 328 if err != nil { 329 return ci.cn.handleError(err) 330 } 331 332 <-ci.done 333 ci.cn.inProgress.Store(false) 334 335 if err := ci.err(); err != nil { 336 return err 337 } 338 return nil 339 }