libeufin

Integration and sandbox testing for FinTech APIs and data formats
Log | Files | Refs | Submodules | README | LICENSE

server.kt (15585B)


      1 /*
      2  * This file is part of LibEuFin.
      3  * Copyright (C) 2024, 2025, 2026 Taler Systems S.A.
      4 
      5  * LibEuFin is free software; you can redistribute it and/or modify
      6  * it under the terms of the GNU Affero General Public License as
      7  * published by the Free Software Foundation; either version 3, or
      8  * (at your option) any later version.
      9 
     10  * LibEuFin is distributed in the hope that it will be useful, but
     11  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
     12  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Affero General
     13  * Public License for more details.
     14 
     15  * You should have received a copy of the GNU Affero General Public
     16  * License along with LibEuFin; see the file COPYING.  If not, see
     17  * <http://www.gnu.org/licenses/>
     18  */
     19 
     20 package tech.libeufin.common.api
     21 
     22 import io.github.smiley4.ktoropenapi.OpenApi
     23 import io.github.smiley4.ktoropenapi.OpenApiPlugin
     24 import io.github.smiley4.ktoropenapi.config.*
     25 import io.github.smiley4.ktoropenapi.openApi
     26 import io.github.smiley4.schemakenerator.serialization.SerializationSteps.analyzeTypeUsingKotlinxSerialization
     27 import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.compileReferencingRoot
     28 import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.generateSwaggerSchema
     29 import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.withTitle
     30 import io.github.smiley4.schemakenerator.swagger.SwaggerSteps.RequiredHandling
     31 import io.github.smiley4.schemakenerator.core.CoreSteps.addDiscriminatorProperty
     32 import io.github.smiley4.schemakenerator.core.CoreSteps.handleNameAnnotation
     33 import io.github.smiley4.schemakenerator.swagger.data.*
     34 import io.ktor.http.*
     35 import io.ktor.serialization.kotlinx.json.*
     36 import io.ktor.server.application.*
     37 import io.ktor.server.engine.*
     38 import io.ktor.server.cio.*
     39 import io.ktor.server.plugins.*
     40 import io.ktor.server.plugins.calllogging.*
     41 import io.ktor.server.plugins.contentnegotiation.*
     42 import io.ktor.server.plugins.forwardedheaders.*
     43 import io.ktor.server.plugins.statuspages.*
     44 import io.ktor.server.plugins.callid.*
     45 import io.ktor.server.request.*
     46 import io.ktor.server.response.*
     47 import io.ktor.server.routing.*
     48 import io.ktor.utils.io.*
     49 import io.ktor.util.*
     50 import io.ktor.util.pipeline.*
     51 import io.ktor.http.content.*
     52 import kotlinx.serialization.ExperimentalSerializationApi
     53 import kotlinx.serialization.json.Json
     54 import org.postgresql.util.PSQLState
     55 import org.slf4j.Logger
     56 import org.slf4j.event.Level
     57 import tech.libeufin.common.*
     58 import tech.libeufin.common.db.SERIALIZATION_ERROR
     59 import java.net.InetAddress
     60 import java.sql.SQLException
     61 import java.util.zip.DataFormatException
     62 import java.util.zip.Inflater
     63 
     64 /** Used to store the raw body */
     65 private val RAW_BODY = AttributeKey<ByteArray>("RAW_BODY")
     66 
     67 /** Used to set custom body limit */
     68 val BODY_LIMIT = AttributeKey<Int>("BODY_LIMIT")
     69 
     70 /** Get call raw body */
     71 val ApplicationCall.rawBody: ByteArray get() = attributes.getOrNull(RAW_BODY) ?: ByteArray(0)
     72 
     73 /**
     74  * This plugin apply Taler specific logic
     75  * It checks for body length limit and inflates the requests that have "Content-Encoding: deflate"
     76  * It logs incoming requests and their details
     77  */
     78 fun talerPlugin(logger: Logger): ApplicationPlugin<Unit> {
     79     return createApplicationPlugin("TalerPlugin") {
     80         onCall { call ->
     81             // Handle CORS
     82             call.response.header(HttpHeaders.AccessControlAllowOrigin, "*")
     83             // Handle CORS preflight
     84             if (call.request.httpMethod == HttpMethod.Options) {
     85                 call.response.header(HttpHeaders.AccessControlAllowHeaders, "*")
     86                 call.response.header(HttpHeaders.AccessControlAllowMethods, "*")
     87                 call.respond(HttpStatusCode.NoContent)
     88                 return@onCall
     89             }
     90 
     91             // Log incoming transaction
     92             val requestCall = buildString {
     93                 val path = call.request.path()
     94                 append(call.request.httpMethod.value)
     95                 append(' ')
     96                 append(call.request.path())
     97                 val query = call.request.queryString()
     98                 if (query.isNotEmpty()) {
     99                     append('?')
    100                     append(query)
    101                 }
    102             }
    103             logger.info(requestCall)
    104         }
    105         onCallReceive { call ->
    106             val bodyLimit = call.attributes.getOrNull(BODY_LIMIT) ?: MAX_BODY_LENGTH
    107             // Check content length if present and wellformed
    108             val contentLenght = call.request.headers[HttpHeaders.ContentLength]?.toIntOrNull()
    109             if (contentLenght != null && contentLenght > bodyLimit)
    110                 throw bodyOverflow("Body is suspiciously big > ${bodyLimit}B")
    111 
    112             // Else check while reading and decompressing the body
    113             transformBody { body ->
    114                 val bytes = ByteArray(bodyLimit + 1)
    115                 var read = 0
    116                 when (val encoding = call.request.headers[HttpHeaders.ContentEncoding])  {
    117                     "deflate" -> {
    118                         // Decompress and check decompressed length
    119                         val inflater = Inflater()
    120                         while (!body.isClosedForRead) {
    121                             body.read { buf ->
    122                                 inflater.setInput(buf)
    123                                 try {
    124                                     read += inflater.inflate(bytes, read, bytes.size - read)
    125                                 } catch (e: DataFormatException) {
    126                                     logger.error("Deflated request failed to inflate: ${e.message}")
    127                                     throw badRequest(
    128                                         "Could not inflate request",
    129                                         TalerErrorCode.GENERIC_COMPRESSION_INVALID
    130                                     )
    131                                 }
    132                             }
    133                             if (read > bodyLimit)
    134                                 throw bodyOverflow("Decompressed body is suspiciously big > ${bodyLimit}B")
    135                         }
    136                     }
    137                     null -> {
    138                         // Check body length
    139                         while (true) {
    140                             val new = body.readAvailable(bytes, read, bytes.size - read)
    141                             if (new == -1) break // Channel is closed
    142                             read += new
    143                             if (read > bodyLimit)
    144                                 throw bodyOverflow("Body is suspiciously big > ${bodyLimit}B")
    145                         }
    146                     } 
    147                     else -> throw unsupportedMediaType(
    148                         "Content encoding '$encoding' not supported, expected plain or deflate",
    149                         TalerErrorCode.GENERIC_COMPRESSION_INVALID
    150                     )
    151                 }
    152                 logger.trace {
    153                     "request ${bytes.sliceArray(0 until read).asUtf8()}"
    154                 }
    155                 call.attributes.put(RAW_BODY, bytes)
    156                 ByteReadChannel(bytes, 0, read)
    157             }
    158         }
    159     }
    160 }
    161 
    162 data class OpenApiInfo(
    163     val title: String,
    164     val version: String,
    165     val description: String? = null,
    166     val securityConfig: SecurityConfig.() -> Unit
    167 )
    168 
    169 /** Set up web server handlers for a Taler API */
    170 fun Application.talerApi(logger: Logger, openApiInfo: OpenApiInfo? = null, serveSpec: Boolean = false, routes: Routing.() -> Unit) {
    171     if (openApiInfo != null) {
    172         install(OpenApi) {
    173             info {
    174                 title = openApiInfo.title
    175                 version = openApiInfo.version
    176                 description = openApiInfo.description
    177             }
    178             server {
    179                 url = "/"
    180                 description = "Same host"
    181             }
    182             security(openApiInfo.securityConfig)
    183             outputFormat = OutputFormat.YAML
    184             schemas {
    185                 generator = { type ->
    186                     type
    187                         .analyzeTypeUsingKotlinxSerialization()
    188                         .handleNameAnnotation()
    189                         .addDiscriminatorProperty("type")
    190                         .generateSwaggerSchema {
    191                             nullables = RequiredHandling.NON_REQUIRED
    192                             optionals = RequiredHandling.REQUIRED
    193                         }
    194                         .withTitle(TitleType.MINIMAL)
    195                         .compileReferencingRoot(
    196                             explicitNullTypes = false,
    197                             pathType = RefType.OPENAPI_MINIMAL
    198                         )
    199                 }
    200             }
    201         }
    202     }
    203     install(CallId) {
    204         generate(10, "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    205         verify { true }
    206     }
    207     install(CallLogging) {
    208         callIdMdc("call-id")
    209         level = Level.INFO
    210         this.logger = logger
    211         format { call ->
    212             val status = call.response.status()
    213             val msg = call.logMsg()
    214             if (msg != null) {
    215                 "${status?.value} ${call.processingTimeMillis()}ms: $msg"
    216             } else {
    217                 "${status?.value} ${call.processingTimeMillis()}ms"
    218             }
    219         }
    220     }
    221     install(XForwardedHeaders)
    222     install(talerPlugin(logger))
    223     install(IgnoreTrailingSlash)
    224     install(ContentNegotiation) {
    225         json(Json {
    226             @OptIn(ExperimentalSerializationApi::class)
    227             explicitNulls = false
    228             encodeDefaults = true
    229             ignoreUnknownKeys = true
    230         })
    231     }
    232     install(StatusPages) {
    233         status(HttpStatusCode.NotFound) { call, status ->
    234             call.err(
    235                 status,
    236                 "There is no endpoint defined for the URL provided by the client. Check if you used the correct URL and/or file a report with the developers of the client software.",
    237                 TalerErrorCode.GENERIC_ENDPOINT_UNKNOWN,
    238                 null
    239             )
    240         }
    241         status(HttpStatusCode.MethodNotAllowed) { call, status ->
    242             call.err(
    243                 status,
    244                 "The HTTP method used is invalid for this endpoint. This is likely a bug in the client implementation. Check if you are using the latest available version and/or file a report with the developers.",
    245                 TalerErrorCode.GENERIC_METHOD_INVALID,
    246                 null
    247             )
    248         }
    249         exception<Exception> { call, cause ->
    250             logger.debug("", cause)
    251             when (cause) {
    252                 is ApiException -> call.err(cause, null)
    253                 is SQLException -> {
    254                     if (SERIALIZATION_ERROR.contains(cause.sqlState)) {
    255                         call.err(
    256                             HttpStatusCode.InternalServerError,
    257                             "Transaction serialization failure",
    258                             TalerErrorCode.BANK_SOFT_EXCEPTION,
    259                             cause
    260                         )
    261                     } else {
    262                         call.err(
    263                             HttpStatusCode.InternalServerError,
    264                             "Unexpected sql error with state ${cause.sqlState}",
    265                             TalerErrorCode.BANK_UNMANAGED_EXCEPTION,
    266                             cause
    267                         )
    268                     }
    269                 }
    270                 is BadRequestException -> {
    271                     /**
    272                      * NOTE: extracting the root cause helps with JSON error messages,
    273                      * because they mention the particular way they are invalid, but OTOH
    274                      * it loses (by getting null) other error messages, like for example
    275                      * the one from MissingRequestParameterException.  Therefore, in order
    276                      * to get the most detailed message, we must consider BOTH sides:
    277                      * the 'cause' AND its root cause!
    278                      */
    279                     var rootCause: Throwable? = cause.cause
    280                     while (rootCause?.cause != null)
    281                         rootCause = rootCause.cause
    282                     // Telling apart invalid JSON vs missing parameter vs invalid parameter.
    283                     val errorCode = when {
    284                         cause is MissingRequestParameterException ->
    285                             TalerErrorCode.GENERIC_PARAMETER_MISSING
    286                         cause is ParameterConversionException ->
    287                             TalerErrorCode.GENERIC_PARAMETER_MALFORMED
    288                         rootCause is CommonError -> when (rootCause) {
    289                             is CommonError.AmountFormat -> TalerErrorCode.BANK_BAD_FORMAT_AMOUNT
    290                             is CommonError.AmountNumberTooBig -> TalerErrorCode.BANK_NUMBER_TOO_BIG
    291                             is CommonError.Payto -> TalerErrorCode.GENERIC_JSON_INVALID
    292                         }
    293                         else -> TalerErrorCode.GENERIC_JSON_INVALID
    294                     }
    295                     call.err(
    296                         HttpStatusCode.BadRequest,
    297                         rootCause?.message,
    298                         errorCode,
    299                         null
    300                     )
    301                 }
    302                 is CommonError -> {
    303                     val errorCode = when (cause) {
    304                         is CommonError.AmountFormat -> TalerErrorCode.BANK_BAD_FORMAT_AMOUNT
    305                         is CommonError.AmountNumberTooBig -> TalerErrorCode.BANK_NUMBER_TOO_BIG
    306                         is CommonError.Payto -> TalerErrorCode.GENERIC_JSON_INVALID
    307                     }
    308                     call.err(
    309                         HttpStatusCode.BadRequest,
    310                         cause.message,
    311                         errorCode,
    312                         null
    313                     )
    314                 }
    315                 else -> {
    316                     call.err(
    317                         HttpStatusCode.InternalServerError,
    318                         cause.message,
    319                         TalerErrorCode.BANK_UNMANAGED_EXCEPTION,
    320                         cause
    321                     )
    322                 }
    323             }
    324         }
    325     }
    326     val phase = PipelinePhase("phase")
    327     sendPipeline.insertPhaseBefore(ApplicationSendPipeline.Engine, phase)
    328     sendPipeline.intercept(phase) { response ->
    329         if (logger.isTraceEnabled) {
    330             if (response is OutgoingContent.ByteArrayContent) {
    331                 logger.trace("response ${String(response.bytes())}")
    332             }
    333         }
    334         
    335     }
    336     routing {
    337         routes()
    338         if (serveSpec) {
    339             route("openapi.yaml") {
    340                 openApi()
    341             }
    342         }
    343     }
    344 }
    345 
    346 // Dirty local variable to stop the server in test TODO remove this ugly hack
    347 var engine: ApplicationEngine? = null
    348 
    349 fun serve(cfg: tech.libeufin.common.ServerConfig, logger: Logger, api: Application.() -> Unit) {
    350     val server = embeddedServer(CIO,
    351         configure = {
    352             when (cfg) {
    353                 is ServerConfig.Tcp -> {
    354                     for (addr in InetAddress.getAllByName(cfg.addr)) {
    355                         logger.info("Listening on ${addr.hostAddress}:${cfg.port}")
    356                         connector {
    357                             port = cfg.port
    358                             host = addr.hostAddress
    359                         }
    360                     }
    361                 }
    362                 is ServerConfig.Unix -> {
    363                     logger.info("Listening on ${cfg.path}")
    364                     unixConnector(cfg.path.toString())
    365                 }
    366             }
    367         },
    368         module = api
    369     )
    370     engine = server.engine
    371     server.start(wait = true)
    372 }