encode.go (10728B)
1 package pq 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "encoding/hex" 7 "errors" 8 "fmt" 9 "strconv" 10 "strings" 11 "time" 12 13 "github.com/lib/pq/internal/pqtime" 14 "github.com/lib/pq/oid" 15 ) 16 17 func binaryEncode(x any) ([]byte, error) { 18 switch v := x.(type) { 19 case []byte: 20 return v, nil 21 default: 22 return encode(x, oid.T_unknown) 23 } 24 } 25 26 func encode(x any, pgtypOid oid.Oid) ([]byte, error) { 27 switch v := x.(type) { 28 case int64: 29 return strconv.AppendInt(nil, v, 10), nil 30 case float64: 31 return strconv.AppendFloat(nil, v, 'f', -1, 64), nil 32 case []byte: 33 if v == nil { 34 return nil, nil 35 } 36 if pgtypOid == oid.T_bytea { 37 return encodeBytea(v), nil 38 } 39 return v, nil 40 case string: 41 if pgtypOid == oid.T_bytea { 42 return encodeBytea([]byte(v)), nil 43 } 44 return []byte(v), nil 45 case bool: 46 return strconv.AppendBool(nil, v), nil 47 case time.Time: 48 return formatTS(v), nil 49 default: 50 return nil, fmt.Errorf("pq: encode: unknown type for %T", v) 51 } 52 } 53 54 func decode(ps *parameterStatus, s []byte, typ oid.Oid, f format) (any, error) { 55 switch f { 56 case formatBinary: 57 return binaryDecode(s, typ) 58 case formatText: 59 return textDecode(ps, s, typ) 60 default: 61 panic("unreachable") 62 } 63 } 64 65 func binaryDecode(s []byte, typ oid.Oid) (any, error) { 66 switch typ { 67 case oid.T_bytea: 68 return s, nil 69 case oid.T_int8: 70 return int64(binary.BigEndian.Uint64(s)), nil 71 case oid.T_int4: 72 return int64(int32(binary.BigEndian.Uint32(s))), nil 73 case oid.T_int2: 74 return int64(int16(binary.BigEndian.Uint16(s))), nil 75 case oid.T_uuid: 76 return decodeUUIDBinary(s) 77 default: 78 return nil, fmt.Errorf("pq: don't know how to decode binary parameter of type %d", uint32(typ)) 79 } 80 81 } 82 83 // decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. 84 func decodeUUIDBinary(src []byte) ([]byte, error) { 85 if len(src) != 16 { 86 return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) 87 } 88 89 dst := make([]byte, 36) 90 dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' 91 hex.Encode(dst[0:], src[0:4]) 92 hex.Encode(dst[9:], src[4:6]) 93 hex.Encode(dst[14:], src[6:8]) 94 hex.Encode(dst[19:], src[8:10]) 95 hex.Encode(dst[24:], src[10:16]) 96 return dst, nil 97 } 98 99 func textDecode(ps *parameterStatus, s []byte, typ oid.Oid) (any, error) { 100 switch typ { 101 case oid.T_char, oid.T_bpchar, oid.T_varchar, oid.T_text: 102 return string(s), nil 103 case oid.T_bytea: 104 b, err := parseBytea(s) 105 if err != nil { 106 err = errors.New("pq: " + err.Error()) 107 } 108 return b, err 109 case oid.T_timestamptz: 110 return parseTS(ps.currentLocation, string(s)) 111 case oid.T_timestamp, oid.T_date: 112 return parseTS(nil, string(s)) 113 case oid.T_time: 114 return parseTime(typ, s) 115 case oid.T_timetz: 116 return parseTime(typ, s) 117 case oid.T_bool: 118 return s[0] == 't', nil 119 case oid.T_int8, oid.T_int4, oid.T_int2: 120 i, err := strconv.ParseInt(string(s), 10, 64) 121 if err != nil { 122 err = errors.New("pq: " + err.Error()) 123 } 124 return i, err 125 case oid.T_float4, oid.T_float8: 126 // We always use 64 bit parsing, regardless of whether the input text is for 127 // a float4 or float8, because clients expect float64s for all float datatypes 128 // and returning a 32-bit parsed float64 produces lossy results. 129 f, err := strconv.ParseFloat(string(s), 64) 130 if err != nil { 131 err = errors.New("pq: " + err.Error()) 132 } 133 return f, err 134 } 135 return s, nil 136 } 137 138 // appendEncodedText encodes item in text format as required by COPY 139 // and appends to buf 140 func appendEncodedText(buf []byte, x any) ([]byte, error) { 141 switch v := x.(type) { 142 case int64: 143 return strconv.AppendInt(buf, v, 10), nil 144 case float64: 145 return strconv.AppendFloat(buf, v, 'f', -1, 64), nil 146 case []byte: 147 encodedBytea := encodeBytea(v) 148 return appendEscapedText(buf, string(encodedBytea)), nil 149 case string: 150 return appendEscapedText(buf, v), nil 151 case bool: 152 return strconv.AppendBool(buf, v), nil 153 case time.Time: 154 return append(buf, formatTS(v)...), nil 155 case nil: 156 return append(buf, `\N`...), nil 157 default: 158 return nil, fmt.Errorf("pq: encode: unknown type for %T", v) 159 } 160 } 161 162 func appendEscapedText(buf []byte, text string) []byte { 163 escapeNeeded := false 164 startPos := 0 165 166 // check if we need to escape 167 for i := 0; i < len(text); i++ { 168 c := text[i] 169 if c == '\\' || c == '\n' || c == '\r' || c == '\t' { 170 escapeNeeded = true 171 startPos = i 172 break 173 } 174 } 175 if !escapeNeeded { 176 return append(buf, text...) 177 } 178 179 // copy till first char to escape, iterate the rest 180 result := append(buf, text[:startPos]...) 181 for i := startPos; i < len(text); i++ { 182 switch c := text[i]; c { 183 case '\\': 184 result = append(result, '\\', '\\') 185 case '\n': 186 result = append(result, '\\', 'n') 187 case '\r': 188 result = append(result, '\\', 'r') 189 case '\t': 190 result = append(result, '\\', 't') 191 default: 192 result = append(result, c) 193 } 194 } 195 return result 196 } 197 198 func parseTime(typ oid.Oid, s []byte) (time.Time, error) { 199 str := string(s) 200 201 f := "15:04:05" 202 if typ == oid.T_timetz { 203 f = "15:04:05-07" 204 // PostgreSQL just sends the hour if the minute and second is 0: 205 // 22:04:59+00 206 // 22:04:59+08 207 // 22:04:59+08:30 208 // 22:04:59+08:30:40 209 // 23:00:00.112321+02:12:13 210 // So add those to the format string. 211 c := strings.Count(str, ":") 212 if c > 3 { 213 f = "15:04:05-07:00:00" 214 } else if c > 2 { 215 f = "15:04:05-07:00" 216 } 217 } 218 219 // Go doesn't parse 24:00, so manually set that to midnight on Jan 2. 24:00 220 // is never with subseconds but may have a timezone: 221 // 24:00:00 222 // 24:00:00+08 223 // 24:00:00-08:01:01 224 var is2400Time bool 225 if strings.HasPrefix(str, "24:00:00") { 226 is2400Time = true 227 if len(str) > 8 { 228 str = "00:00:00" + str[8:] 229 } else { 230 str = "00:00:00" 231 } 232 } 233 234 t, err := time.Parse(f, str) 235 if err != nil { 236 return time.Time{}, errors.New("pq: " + err.Error()) 237 } 238 if is2400Time { 239 t = t.Add(24 * time.Hour) 240 } 241 // TODO(v2): it uses UTC, which it shouldn't. But I'm afraid changing it now 242 // will break people's code. 243 //if typ == oid.T_time { 244 // // Don't use UTC but time.FixedZone("", 0) 245 // t = t.In(globalLocationCache.getLocation(0)) 246 //} 247 return t, nil 248 } 249 250 var ( 251 infinityTSEnabled = false 252 infinityTSNegative time.Time 253 infinityTSPositive time.Time 254 ) 255 256 // EnableInfinityTs controls the handling of Postgres' "-infinity" and 257 // "infinity" "timestamp"s. 258 // 259 // If EnableInfinityTs is not called, "-infinity" and "infinity" will return 260 // []byte("-infinity") and []byte("infinity") respectively, and potentially 261 // cause error "sql: Scan error on column index 0: unsupported driver -> Scan 262 // pair: []uint8 -> *time.Time", when scanning into a time.Time value. 263 // 264 // Once EnableInfinityTs has been called, all connections created using this 265 // driver will decode Postgres' "-infinity" and "infinity" for "timestamp", 266 // "timestamp with time zone" and "date" types to the predefined minimum and 267 // maximum times, respectively. When encoding time.Time values, any time which 268 // equals or precedes the predefined minimum time will be encoded to 269 // "-infinity". Any values at or past the maximum time will similarly be 270 // encoded to "infinity". 271 // 272 // If EnableInfinityTs is called with negative >= positive, it will panic. 273 // Calling EnableInfinityTs after a connection has been established results in 274 // undefined behavior. If EnableInfinityTs is called more than once, it will 275 // panic. 276 func EnableInfinityTs(negative time.Time, positive time.Time) { 277 if infinityTSEnabled { 278 panic("pq: infinity timestamp already enabled") 279 } 280 if !negative.Before(positive) { 281 panic("pq: infinity timestamp: negative value must be smaller (before) than positive") 282 } 283 infinityTSEnabled = true 284 infinityTSNegative = negative 285 infinityTSPositive = positive 286 } 287 288 // Testing might want to toggle infinityTSEnabled 289 func disableInfinityTS() { 290 infinityTSEnabled = false 291 } 292 293 // This is a time function specific to the Postgres default DateStyle setting 294 // ("ISO, MDY"), the only one we currently support. This accounts for the 295 // discrepancies between the parsing available with time.Parse and the Postgres 296 // date formatting quirks. 297 func parseTS(currentLocation *time.Location, str string) (any, error) { 298 switch str { 299 case "-infinity": 300 if infinityTSEnabled { 301 return infinityTSNegative, nil 302 } 303 return []byte(str), nil 304 case "infinity": 305 if infinityTSEnabled { 306 return infinityTSPositive, nil 307 } 308 return []byte(str), nil 309 } 310 t, err := ParseTimestamp(currentLocation, str) 311 if err != nil { 312 err = errors.New("pq: " + err.Error()) 313 } 314 return t, err 315 } 316 317 // ParseTimestamp parses Postgres' text format. It returns a time.Time in 318 // currentLocation iff that time's offset agrees with the offset sent from the 319 // Postgres server. Otherwise, ParseTimestamp returns a time.Time with the fixed 320 // offset offset provided by the Postgres server. 321 func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { 322 return pqtime.Parse(currentLocation, str) 323 } 324 325 // formatTS formats t into a format postgres understands. 326 func formatTS(t time.Time) []byte { 327 if infinityTSEnabled { 328 // t <= -infinity : ! (t > -infinity) 329 if !t.After(infinityTSNegative) { 330 return []byte("-infinity") 331 } 332 // t >= infinity : ! (!t < infinity) 333 if !t.Before(infinityTSPositive) { 334 return []byte("infinity") 335 } 336 } 337 return FormatTimestamp(t) 338 } 339 340 // FormatTimestamp formats t into Postgres' text format for timestamps. 341 func FormatTimestamp(t time.Time) []byte { 342 return pqtime.Format(t) 343 } 344 345 // Parse a bytea value received from the server. Both "hex" and the legacy 346 // "escape" format are supported. 347 func parseBytea(s []byte) (result []byte, err error) { 348 // Hex format. 349 if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { 350 s = s[2:] // trim off leading "\\x" 351 result = make([]byte, hex.DecodedLen(len(s))) 352 _, err := hex.Decode(result, s) 353 if err != nil { 354 return nil, err 355 } 356 return result, nil 357 } 358 359 // Escape format. 360 for len(s) > 0 { 361 if s[0] == '\\' { 362 // escaped '\\' 363 if len(s) >= 2 && s[1] == '\\' { 364 result = append(result, '\\') 365 s = s[2:] 366 continue 367 } 368 369 // '\\' followed by an octal number 370 if len(s) < 4 { 371 return nil, fmt.Errorf("invalid bytea sequence %v", s) 372 } 373 r, err := strconv.ParseUint(string(s[1:4]), 8, 8) 374 if err != nil { 375 return nil, fmt.Errorf("could not parse bytea value: %w", err) 376 } 377 result = append(result, byte(r)) 378 s = s[4:] 379 } else { 380 // We hit an unescaped, raw byte. Try to read in as many as 381 // possible in one go. 382 i := bytes.IndexByte(s, '\\') 383 if i == -1 { 384 result = append(result, s...) 385 break 386 } 387 result = append(result, s[:i]...) 388 s = s[i:] 389 } 390 } 391 return result, nil 392 } 393 394 func encodeBytea(v []byte) (result []byte) { 395 result = make([]byte, 2+hex.EncodedLen(len(v))) 396 result[0] = '\\' 397 result[1] = 'x' 398 hex.Encode(result[2:], v) 399 return result 400 }