base64.rs (5629B)
1 /* 2 This file is part of TALER 3 Copyright (C) 2026 Taler Systems SA 4 5 TALER is free software; you can redistribute it and/or modify it under the 6 terms of the GNU Affero General Public License as published by the Free Software 7 Foundation; either version 3, or (at your option) any later version. 8 9 TALER is distributed in the hope that it will be useful, but WITHOUT ANY 10 WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 11 A PARTICULAR PURPOSE. See the GNU Affero General Public License for more details. 12 13 You should have received a copy of the GNU Affero General Public License along with 14 TALER; see the file COPYING. If not, see <http://www.gnu.org/licenses/> 15 */ 16 17 use std::fmt::Display; 18 19 pub const BASE64_ALPHABET: &[u8] = 20 b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; 21 22 /** Encoded bytes len of base64 with padding */ 23 #[inline] 24 const fn encoded_len(len: usize) -> usize { 25 (len * 4).div_ceil(3).next_multiple_of(4) 26 } 27 28 /** Encode a chunk using base64 */ 29 #[inline(always)] 30 fn encode_chunk(chunk: &[u8], encoded: &mut [u8]) { 31 let mut buf = [0u8; 3]; 32 for (i, &b) in chunk.iter().enumerate() { 33 buf[i] = b; 34 } 35 encoded[0] = BASE64_ALPHABET[((buf[0] & 0xFC) >> 2) as usize]; 36 encoded[1] = BASE64_ALPHABET[(((buf[0] & 0x03) << 4) | ((buf[1] & 0xF0) >> 4)) as usize]; 37 if chunk.len() > 1 { 38 encoded[2] = BASE64_ALPHABET[(((buf[1] & 0x0F) << 2) | ((buf[2] & 0xC0) >> 6)) as usize]; 39 } 40 if chunk.len() > 2 { 41 encoded[3] = BASE64_ALPHABET[(buf[2] & 0x3F) as usize]; 42 } 43 } 44 45 /** Encode bytes using standard base64 with `=` padding */ 46 pub fn encode(bytes: impl AsRef<[u8]>) -> String { 47 let bytes = bytes.as_ref(); 48 let mut buf = vec![b'='; encoded_len(bytes.len())]; 49 50 for (chunk, buf) in bytes.chunks(3).zip(buf.chunks_exact_mut(4)) { 51 encode_chunk(chunk, buf) 52 } 53 54 // SAFETY: only contains ASCII characters from BASE64_ALPHABET or b'=' 55 unsafe { String::from_utf8_unchecked(buf) } 56 } 57 58 /** Format bytes using standard base64 with `=` padding */ 59 pub fn fmt(bytes: impl AsRef<[u8]>) -> impl Display { 60 std::fmt::from_fn(move |f| { 61 for chunk in bytes.as_ref().chunks(3) { 62 let mut tmp = [0u8; 3]; 63 for (i, &b) in chunk.iter().enumerate() { 64 tmp[i] = b; 65 } 66 67 let mut out = [b'='; 4]; 68 out[0] = BASE64_ALPHABET[((tmp[0] & 0xFC) >> 2) as usize]; 69 out[1] = BASE64_ALPHABET[(((tmp[0] & 0x03) << 4) | ((tmp[1] & 0xF0) >> 4)) as usize]; 70 if chunk.len() > 1 { 71 out[2] = 72 BASE64_ALPHABET[(((tmp[1] & 0x0F) << 2) | ((tmp[2] & 0xC0) >> 6)) as usize]; 73 } 74 if chunk.len() > 2 { 75 out[3] = BASE64_ALPHABET[(tmp[2] & 0x3F) as usize]; 76 } 77 78 // SAFETY: out contains only ASCII characters from BASE64_ALPHABET or b'=' 79 f.write_str(unsafe { std::str::from_utf8_unchecked(&out) })?; 80 } 81 Ok(()) 82 }) 83 } 84 85 #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] 86 pub enum Base64Error { 87 #[error("invalid base64 format")] 88 Format, 89 #[error("invalid length: base64 input must be a multiple of 4")] 90 Length, 91 } 92 93 const BASE64_INV: [u8; 256] = { 94 let mut table = [255u8; 256]; 95 let mut i = 0; 96 while i < 64 { 97 table[BASE64_ALPHABET[i] as usize] = i as u8; 98 i += 1; 99 } 100 table 101 }; 102 103 /** Unpadded decoded length from a padded base64 string */ 104 fn decoded_len(encoded: &[u8]) -> usize { 105 let padding = encoded.iter().rev().take_while(|&&b| b == b'=').count(); 106 encoded.len() * 3 / 4 - padding 107 } 108 109 /** Decode a standard base64 string (with `=` padding) */ 110 pub fn decode(encoded: impl AsRef<[u8]>) -> Result<Vec<u8>, Base64Error> { 111 let encoded = encoded.as_ref(); 112 if encoded.len() % 4 != 0 { 113 return Err(Base64Error::Length); 114 } 115 116 let out_len = decoded_len(encoded); 117 let mut decoded = Vec::with_capacity(out_len); 118 let mut invalid = false; 119 120 for chunk in encoded.chunks(4) { 121 let mut buf = [0u8; 4]; 122 // Lookup chunk 123 for (i, &b) in chunk.iter().enumerate() { 124 buf[i] = if b == b'=' { 0 } else { BASE64_INV[b as usize] }; 125 } 126 127 // Check chunk validity 128 invalid |= buf.contains(&255); 129 130 // Decode chunk 131 decoded.push((buf[0] << 2) | (buf[1] >> 4)); 132 if chunk[2] != b'=' { 133 decoded.push((buf[1] << 4) | (buf[2] >> 2)); 134 } 135 if chunk[3] != b'=' { 136 decoded.push((buf[2] << 6) | buf[3]); 137 } 138 } 139 140 if invalid { 141 return Err(Base64Error::Format); 142 } 143 144 Ok(decoded) 145 } 146 147 #[cfg(test)] 148 mod test { 149 use crate::encoding::base64::{Base64Error, decode, encode, fmt}; 150 151 #[test] 152 fn base64() { 153 // RFC test vectors 154 for (decoded, encoded) in [ 155 ("", ""), 156 ("f", "Zg=="), 157 ("fo", "Zm8="), 158 ("foo", "Zm9v"), 159 ("foob", "Zm9vYg=="), 160 ("fooba", "Zm9vYmE="), 161 ("foobar", "Zm9vYmFy"), 162 ] { 163 assert_eq!(encode(decoded.as_bytes()), encoded); 164 assert_eq!(fmt(decoded).to_string(), encoded); 165 assert_eq!(decode(encoded.as_bytes()).unwrap(), decoded.as_bytes()); 166 } 167 168 // Invalid length 169 assert_eq!(decode(b"Zg="), Err(Base64Error::Length)); 170 assert_eq!(decode(b"Z"), Err(Base64Error::Length)); 171 172 // Invalid characters 173 assert_eq!(decode(b"Zg=!"), Err(Base64Error::Format)); 174 assert_eq!(decode(b"Z\x00=="), Err(Base64Error::Format)); 175 } 176 }