Reinstall limiter, add buffering, fix TLS setup

This commit is contained in:
carbotaniuman 2020-06-09 14:21:18 -05:00
parent d1629e9f5c
commit 35ea86c6ac
5 changed files with 85 additions and 17 deletions

View file

@ -7,7 +7,7 @@ plugins {
} }
group = "com.mangadex" group = "com.mangadex"
version = "1.0.0-rc8" version = "1.0.0-rc9"
mainClassName = "mdnet.base.MangaDexClient" mainClassName = "mdnet.base.MangaDexClient"
repositories { repositories {

View file

@ -6,4 +6,7 @@ public class Constants {
public static final int CLIENT_BUILD = 2; public static final int CLIENT_BUILD = 2;
public static final String CLIENT_VERSION = "1.0"; public static final String CLIENT_VERSION = "1.0";
public static final Duration MAX_AGE_CACHE = Duration.ofDays(14); public static final Duration MAX_AGE_CACHE = Duration.ofDays(14);
public static final int MAX_CONCURRENT_CONNECTIONS = 2;
public static final String OVERLOADED_MESSAGE = "This server is experiencing a surge in connections. Please try again later.";
} }

View file

@ -23,6 +23,8 @@ import org.http4k.routing.routes
import org.http4k.server.Http4kServer import org.http4k.server.Http4kServer
import org.http4k.server.asServer import org.http4k.server.asServer
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.BufferedInputStream
import java.io.BufferedOutputStream
import java.io.InputStream import java.io.InputStream
import java.security.MessageDigest import java.security.MessageDigest
import java.time.ZoneOffset import java.time.ZoneOffset
@ -37,10 +39,15 @@ import javax.crypto.CipherOutputStream
import javax.crypto.spec.SecretKeySpec import javax.crypto.spec.SecretKeySpec
private val LOGGER = LoggerFactory.getLogger("Application") private val LOGGER = LoggerFactory.getLogger("Application")
private val THREADS_TO_ALLOCATE = Runtime.getRuntime().availableProcessors() * 30 / 2 ;
fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSettings: ClientSettings, statistics: AtomicReference<Statistics>): Http4kServer { fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSettings: ClientSettings, statistics: AtomicReference<Statistics>): Http4kServer {
val executor = Executors.newCachedThreadPool() val executor = Executors.newCachedThreadPool()
if (LOGGER.isInfoEnabled) {
LOGGER.info("Starting ApacheClient with {} threads", THREADS_TO_ALLOCATE)
}
val client = ApacheClient(responseBodyMode = BodyMode.Stream, client = HttpClients.custom() val client = ApacheClient(responseBodyMode = BodyMode.Stream, client = HttpClients.custom()
.setDefaultRequestConfig(RequestConfig.custom() .setDefaultRequestConfig(RequestConfig.custom()
.setCookieSpec(CookieSpecs.IGNORE_COOKIES) .setCookieSpec(CookieSpecs.IGNORE_COOKIES)
@ -48,8 +55,8 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting
.setSocketTimeout(3000) .setSocketTimeout(3000)
.setConnectionRequestTimeout(3000) .setConnectionRequestTimeout(3000)
.build()) .build())
.setMaxConnTotal(75) .setMaxConnTotal(THREADS_TO_ALLOCATE)
.setMaxConnPerRoute(75) .setMaxConnPerRoute(THREADS_TO_ALLOCATE)
.build()) .build())
val app = { dataSaver: Boolean -> val app = { dataSaver: Boolean ->
@ -122,7 +129,7 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting
} }
respondWithImage( respondWithImage(
CipherInputStream(snapshot.getInputStream(0), getRc4(rc4Bytes)), CipherInputStream(BufferedInputStream(snapshot.getInputStream(0)), getRc4(rc4Bytes)),
snapshot.getLength(0).toString(), snapshot.getString(1), snapshot.getString(2) snapshot.getLength(0).toString(), snapshot.getString(1), snapshot.getString(2)
) )
} }
@ -161,19 +168,19 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting
val tee = CachingInputStream( val tee = CachingInputStream(
mdResponse.body.stream, mdResponse.body.stream,
executor, CipherOutputStream(editor.newOutputStream(0), getRc4(rc4Bytes)) executor, CipherOutputStream(BufferedOutputStream(editor.newOutputStream(0)), getRc4(rc4Bytes))
) { ) {
// Note: if neither of the options get called/are in the log // Note: if neither of the options get called/are in the log
// check that tee gets closed and for exceptions in this lambda // check that tee gets closed and for exceptions in this lambda
if (editor.getLength(0) == contentLength.toLong()) { if (editor.getLength(0) == contentLength.toLong()) {
if (LOGGER.isInfoEnabled) { if (LOGGER.isInfoEnabled) {
LOGGER.info("Cache download $sanitizedUri committed") LOGGER.info("Cache download for $sanitizedUri committed")
} }
editor.commit() editor.commit()
} else { } else {
if (LOGGER.isInfoEnabled) { if (LOGGER.isInfoEnabled) {
LOGGER.info("Cache download $sanitizedUri aborted") LOGGER.info("Cache download for $sanitizedUri aborted")
} }
editor.abort() editor.abort()

View file

@ -36,15 +36,6 @@ private const val PKCS_1_PEM_FOOTER = "-----END RSA PRIVATE KEY-----"
private const val PKCS_8_PEM_HEADER = "-----BEGIN PRIVATE KEY-----" private const val PKCS_8_PEM_HEADER = "-----BEGIN PRIVATE KEY-----"
private const val PKCS_8_PEM_FOOTER = "-----END PRIVATE KEY-----" private const val PKCS_8_PEM_FOOTER = "-----END PRIVATE KEY-----"
fun getX509Cert(certificate: String): X509Certificate {
val targetStream: InputStream = ByteArrayInputStream(certificate.toByteArray())
return CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate
}
fun getPrivateKey(privateKey: String): PrivateKey {
return loadKey(privateKey)!!
}
fun loadKey(keyDataString: String): PrivateKey? { fun loadKey(keyDataString: String): PrivateKey? {
if (keyDataString.contains(PKCS_1_PEM_HEADER)) { if (keyDataString.contains(PKCS_1_PEM_HEADER)) {
// OpenSSL / PKCS#1 Base64 PEM encoded file // OpenSSL / PKCS#1 Base64 PEM encoded file

View file

@ -1,8 +1,10 @@
package mdnet.base package mdnet.base
import io.netty.bootstrap.ServerBootstrap import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.Unpooled
import io.netty.channel.ChannelFactory import io.netty.channel.ChannelFactory
import io.netty.channel.ChannelFuture import io.netty.channel.ChannelFuture
import io.netty.channel.ChannelHandler
import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelHandlerContext
import io.netty.channel.ChannelInboundHandlerAdapter import io.netty.channel.ChannelInboundHandlerAdapter
import io.netty.channel.ChannelInitializer import io.netty.channel.ChannelInitializer
@ -12,10 +14,16 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.http.DefaultFullHttpResponse
import io.netty.handler.codec.http.HttpHeaderNames
import io.netty.handler.codec.http.HttpObjectAggregator import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpResponseStatus
import io.netty.handler.codec.http.HttpServerCodec import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpUtil
import io.netty.handler.codec.http.HttpVersion
import io.netty.handler.ssl.OptionalSslHandler import io.netty.handler.ssl.OptionalSslHandler
import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.ssl.SslHandler
import io.netty.handler.stream.ChunkedWriteHandler import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.handler.traffic.GlobalTrafficShapingHandler import io.netty.handler.traffic.GlobalTrafficShapingHandler
import io.netty.handler.traffic.TrafficCounter import io.netty.handler.traffic.TrafficCounter
@ -24,14 +32,58 @@ import org.http4k.server.Http4kChannelHandler
import org.http4k.server.Http4kServer import org.http4k.server.Http4kServer
import org.http4k.server.ServerConfig import org.http4k.server.ServerConfig
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.io.ByteArrayInputStream
import java.io.IOException import java.io.IOException
import java.io.InputStream
import java.net.InetSocketAddress import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets
import java.security.PrivateKey
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.atomic.AtomicReference
import javax.net.ssl.SSLException import javax.net.ssl.SSLException
private val LOGGER = LoggerFactory.getLogger("Application") private val LOGGER = LoggerFactory.getLogger("Application")
@ChannelHandler.Sharable
class ConnectionCounter : ChannelInboundHandlerAdapter() {
private val connections = AtomicInteger()
override fun channelActive(ctx: ChannelHandlerContext) {
val sslHandler = ctx.pipeline()[SslHandler::class.java]
if (sslHandler != null) {
sslHandler.handshakeFuture().addListener {
handleConnection(ctx)
}
} else {
handleConnection(ctx)
}
}
private fun handleConnection(ctx: ChannelHandlerContext) {
if (connections.incrementAndGet() <= Constants.MAX_CONCURRENT_CONNECTIONS) {
super.channelActive(ctx)
} else {
val response = Unpooled.copiedBuffer(Constants.OVERLOADED_MESSAGE, StandardCharsets.UTF_8)
val res =
DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE, response)
res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8")
HttpUtil.setContentLength(res, response.readableBytes().toLong())
ctx.writeAndFlush(res)
ctx.close()
}
}
override fun channelInactive(ctx: ChannelHandlerContext) {
super.channelInactive(ctx)
connections.decrementAndGet()
}
}
class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: ClientSettings, private val stats: AtomicReference<Statistics>) : ServerConfig { class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: ClientSettings, private val stats: AtomicReference<Statistics>) : ServerConfig {
override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer { override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer {
private val masterGroup = NioEventLoopGroup() private val masterGroup = NioEventLoopGroup()
@ -46,9 +98,14 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings:
counter.resetCumulativeTime() counter.resetCumulativeTime()
} }
} }
private val limiter = ConnectionCounter();
override fun start(): Http4kServer = apply { override fun start(): Http4kServer = apply {
val sslContext = SslContextBuilder.forServer(getPrivateKey(tls.privateKey), getX509Cert(tls.certificate)).build() val (mainCert, chainCert) = getX509Certs(tls.certificate);
val sslContext = SslContextBuilder
.forServer(getPrivateKey(tls.privateKey), mainCert, chainCert)
.protocols("TLSv1.3", "TLSv.1.2", "TLSv.1.1", "TLSv.1.0")
.build()
val bootstrap = ServerBootstrap() val bootstrap = ServerBootstrap()
bootstrap.group(masterGroup, workerGroup) bootstrap.group(masterGroup, workerGroup)
@ -57,6 +114,7 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings:
public override fun initChannel(ch: SocketChannel) { public override fun initChannel(ch: SocketChannel) {
ch.pipeline().addLast("ssl", OptionalSslHandler(sslContext)) ch.pipeline().addLast("ssl", OptionalSslHandler(sslContext))
ch.pipeline().addLast("limiter", limiter)
ch.pipeline().addLast("codec", HttpServerCodec()) ch.pipeline().addLast("codec", HttpServerCodec())
ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536)) ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536))
@ -98,3 +156,12 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings:
override fun port(): Int = if (clientSettings.clientPort > 0) clientSettings.clientPort else address.port override fun port(): Int = if (clientSettings.clientPort > 0) clientSettings.clientPort else address.port
} }
} }
fun getX509Certs(certificates: String): Pair<X509Certificate, X509Certificate> {
val targetStream: InputStream = ByteArrayInputStream(certificates.toByteArray())
return (CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate) to (CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate)
}
fun getPrivateKey(privateKey: String): PrivateKey {
return loadKey(privateKey)!!
}