taler-rust

GNU Taler code in Rust. Largely core banking integrations.
Log | Files | Refs | Submodules | README | LICENSE

base64.rs (5629B)


      1 /*
      2   This file is part of TALER
      3   Copyright (C) 2026 Taler Systems SA
      4 
      5   TALER is free software; you can redistribute it and/or modify it under the
      6   terms of the GNU Affero General Public License as published by the Free Software
      7   Foundation; either version 3, or (at your option) any later version.
      8 
      9   TALER is distributed in the hope that it will be useful, but WITHOUT ANY
     10   WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
     11   A PARTICULAR PURPOSE.  See the GNU Affero General Public License for more details.
     12 
     13   You should have received a copy of the GNU Affero General Public License along with
     14   TALER; see the file COPYING.  If not, see <http://www.gnu.org/licenses/>
     15 */
     16 
     17 use std::fmt::Display;
     18 
     19 pub const BASE64_ALPHABET: &[u8] =
     20     b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
     21 
     22 /** Encoded bytes len of base64 with padding */
     23 #[inline]
     24 const fn encoded_len(len: usize) -> usize {
     25     (len * 4).div_ceil(3).next_multiple_of(4)
     26 }
     27 
     28 /** Encode a chunk using base64 */
     29 #[inline(always)]
     30 fn encode_chunk(chunk: &[u8], encoded: &mut [u8]) {
     31     let mut buf = [0u8; 3];
     32     for (i, &b) in chunk.iter().enumerate() {
     33         buf[i] = b;
     34     }
     35     encoded[0] = BASE64_ALPHABET[((buf[0] & 0xFC) >> 2) as usize];
     36     encoded[1] = BASE64_ALPHABET[(((buf[0] & 0x03) << 4) | ((buf[1] & 0xF0) >> 4)) as usize];
     37     if chunk.len() > 1 {
     38         encoded[2] = BASE64_ALPHABET[(((buf[1] & 0x0F) << 2) | ((buf[2] & 0xC0) >> 6)) as usize];
     39     }
     40     if chunk.len() > 2 {
     41         encoded[3] = BASE64_ALPHABET[(buf[2] & 0x3F) as usize];
     42     }
     43 }
     44 
     45 /** Encode bytes using standard base64 with `=` padding */
     46 pub fn encode(bytes: impl AsRef<[u8]>) -> String {
     47     let bytes = bytes.as_ref();
     48     let mut buf = vec![b'='; encoded_len(bytes.len())];
     49 
     50     for (chunk, buf) in bytes.chunks(3).zip(buf.chunks_exact_mut(4)) {
     51         encode_chunk(chunk, buf)
     52     }
     53 
     54     // SAFETY: only contains ASCII characters from BASE64_ALPHABET or b'='
     55     unsafe { String::from_utf8_unchecked(buf) }
     56 }
     57 
     58 /** Format bytes using standard base64 with `=` padding */
     59 pub fn fmt(bytes: impl AsRef<[u8]>) -> impl Display {
     60     std::fmt::from_fn(move |f| {
     61         for chunk in bytes.as_ref().chunks(3) {
     62             let mut tmp = [0u8; 3];
     63             for (i, &b) in chunk.iter().enumerate() {
     64                 tmp[i] = b;
     65             }
     66 
     67             let mut out = [b'='; 4];
     68             out[0] = BASE64_ALPHABET[((tmp[0] & 0xFC) >> 2) as usize];
     69             out[1] = BASE64_ALPHABET[(((tmp[0] & 0x03) << 4) | ((tmp[1] & 0xF0) >> 4)) as usize];
     70             if chunk.len() > 1 {
     71                 out[2] =
     72                     BASE64_ALPHABET[(((tmp[1] & 0x0F) << 2) | ((tmp[2] & 0xC0) >> 6)) as usize];
     73             }
     74             if chunk.len() > 2 {
     75                 out[3] = BASE64_ALPHABET[(tmp[2] & 0x3F) as usize];
     76             }
     77 
     78             // SAFETY: out contains only ASCII characters from BASE64_ALPHABET or b'='
     79             f.write_str(unsafe { std::str::from_utf8_unchecked(&out) })?;
     80         }
     81         Ok(())
     82     })
     83 }
     84 
     85 #[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)]
     86 pub enum Base64Error {
     87     #[error("invalid base64 format")]
     88     Format,
     89     #[error("invalid length: base64 input must be a multiple of 4")]
     90     Length,
     91 }
     92 
     93 const BASE64_INV: [u8; 256] = {
     94     let mut table = [255u8; 256];
     95     let mut i = 0;
     96     while i < 64 {
     97         table[BASE64_ALPHABET[i] as usize] = i as u8;
     98         i += 1;
     99     }
    100     table
    101 };
    102 
    103 /** Unpadded decoded length from a padded base64 string */
    104 fn decoded_len(encoded: &[u8]) -> usize {
    105     let padding = encoded.iter().rev().take_while(|&&b| b == b'=').count();
    106     encoded.len() * 3 / 4 - padding
    107 }
    108 
    109 /** Decode a standard base64 string (with `=` padding) */
    110 pub fn decode(encoded: impl AsRef<[u8]>) -> Result<Vec<u8>, Base64Error> {
    111     let encoded = encoded.as_ref();
    112     if encoded.len() % 4 != 0 {
    113         return Err(Base64Error::Length);
    114     }
    115 
    116     let out_len = decoded_len(encoded);
    117     let mut decoded = Vec::with_capacity(out_len);
    118     let mut invalid = false;
    119 
    120     for chunk in encoded.chunks(4) {
    121         let mut buf = [0u8; 4];
    122         // Lookup chunk
    123         for (i, &b) in chunk.iter().enumerate() {
    124             buf[i] = if b == b'=' { 0 } else { BASE64_INV[b as usize] };
    125         }
    126 
    127         // Check chunk validity
    128         invalid |= buf.contains(&255);
    129 
    130         // Decode chunk
    131         decoded.push((buf[0] << 2) | (buf[1] >> 4));
    132         if chunk[2] != b'=' {
    133             decoded.push((buf[1] << 4) | (buf[2] >> 2));
    134         }
    135         if chunk[3] != b'=' {
    136             decoded.push((buf[2] << 6) | buf[3]);
    137         }
    138     }
    139 
    140     if invalid {
    141         return Err(Base64Error::Format);
    142     }
    143 
    144     Ok(decoded)
    145 }
    146 
    147 #[cfg(test)]
    148 mod test {
    149     use crate::encoding::base64::{Base64Error, decode, encode, fmt};
    150 
    151     #[test]
    152     fn base64() {
    153         // RFC test vectors
    154         for (decoded, encoded) in [
    155             ("", ""),
    156             ("f", "Zg=="),
    157             ("fo", "Zm8="),
    158             ("foo", "Zm9v"),
    159             ("foob", "Zm9vYg=="),
    160             ("fooba", "Zm9vYmE="),
    161             ("foobar", "Zm9vYmFy"),
    162         ] {
    163             assert_eq!(encode(decoded.as_bytes()), encoded);
    164             assert_eq!(fmt(decoded).to_string(), encoded);
    165             assert_eq!(decode(encoded.as_bytes()).unwrap(), decoded.as_bytes());
    166         }
    167 
    168         // Invalid length
    169         assert_eq!(decode(b"Zg="), Err(Base64Error::Length));
    170         assert_eq!(decode(b"Z"), Err(Base64Error::Length));
    171 
    172         // Invalid characters
    173         assert_eq!(decode(b"Zg=!"), Err(Base64Error::Format));
    174         assert_eq!(decode(b"Z\x00=="), Err(Base64Error::Format));
    175     }
    176 }