diff --git a/build.gradle b/build.gradle index 2ed15ba..9c40348 100644 --- a/build.gradle +++ b/build.gradle @@ -7,7 +7,7 @@ plugins { } group = "com.mangadex" -version = "1.0.0-rc8" +version = "1.0.0-rc9" mainClassName = "mdnet.base.MangaDexClient" repositories { diff --git a/src/main/java/mdnet/base/Constants.java b/src/main/java/mdnet/base/Constants.java index d39a2f6..2d4633a 100644 --- a/src/main/java/mdnet/base/Constants.java +++ b/src/main/java/mdnet/base/Constants.java @@ -6,4 +6,7 @@ public class Constants { public static final int CLIENT_BUILD = 2; public static final String CLIENT_VERSION = "1.0"; public static final Duration MAX_AGE_CACHE = Duration.ofDays(14); + + public static final int MAX_CONCURRENT_CONNECTIONS = 2; + public static final String OVERLOADED_MESSAGE = "This server is experiencing a surge in connections. Please try again later."; } diff --git a/src/main/kotlin/mdnet/base/Application.kt b/src/main/kotlin/mdnet/base/Application.kt index 0020f21..ed2aac0 100644 --- a/src/main/kotlin/mdnet/base/Application.kt +++ b/src/main/kotlin/mdnet/base/Application.kt @@ -23,6 +23,8 @@ import org.http4k.routing.routes import org.http4k.server.Http4kServer import org.http4k.server.asServer import org.slf4j.LoggerFactory +import java.io.BufferedInputStream +import java.io.BufferedOutputStream import java.io.InputStream import java.security.MessageDigest import java.time.ZoneOffset @@ -37,10 +39,15 @@ import javax.crypto.CipherOutputStream import javax.crypto.spec.SecretKeySpec private val LOGGER = LoggerFactory.getLogger("Application") +private val THREADS_TO_ALLOCATE = Runtime.getRuntime().availableProcessors() * 30 / 2 ; fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSettings: ClientSettings, statistics: AtomicReference): Http4kServer { val executor = Executors.newCachedThreadPool() + if (LOGGER.isInfoEnabled) { + LOGGER.info("Starting ApacheClient with {} threads", THREADS_TO_ALLOCATE) + } + val client = ApacheClient(responseBodyMode = BodyMode.Stream, client = HttpClients.custom() .setDefaultRequestConfig(RequestConfig.custom() .setCookieSpec(CookieSpecs.IGNORE_COOKIES) @@ -48,8 +55,8 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting .setSocketTimeout(3000) .setConnectionRequestTimeout(3000) .build()) - .setMaxConnTotal(75) - .setMaxConnPerRoute(75) + .setMaxConnTotal(THREADS_TO_ALLOCATE) + .setMaxConnPerRoute(THREADS_TO_ALLOCATE) .build()) val app = { dataSaver: Boolean -> @@ -122,7 +129,7 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting } respondWithImage( - CipherInputStream(snapshot.getInputStream(0), getRc4(rc4Bytes)), + CipherInputStream(BufferedInputStream(snapshot.getInputStream(0)), getRc4(rc4Bytes)), snapshot.getLength(0).toString(), snapshot.getString(1), snapshot.getString(2) ) } @@ -161,19 +168,19 @@ fun getServer(cache: DiskLruCache, serverSettings: ServerSettings, clientSetting val tee = CachingInputStream( mdResponse.body.stream, - executor, CipherOutputStream(editor.newOutputStream(0), getRc4(rc4Bytes)) + executor, CipherOutputStream(BufferedOutputStream(editor.newOutputStream(0)), getRc4(rc4Bytes)) ) { // Note: if neither of the options get called/are in the log // check that tee gets closed and for exceptions in this lambda if (editor.getLength(0) == contentLength.toLong()) { if (LOGGER.isInfoEnabled) { - LOGGER.info("Cache download $sanitizedUri committed") + LOGGER.info("Cache download for $sanitizedUri committed") } editor.commit() } else { if (LOGGER.isInfoEnabled) { - LOGGER.info("Cache download $sanitizedUri aborted") + LOGGER.info("Cache download for $sanitizedUri aborted") } editor.abort() diff --git a/src/main/kotlin/mdnet/base/Keys.kt b/src/main/kotlin/mdnet/base/Keys.kt index 02a0767..0216796 100644 --- a/src/main/kotlin/mdnet/base/Keys.kt +++ b/src/main/kotlin/mdnet/base/Keys.kt @@ -36,15 +36,6 @@ private const val PKCS_1_PEM_FOOTER = "-----END RSA PRIVATE KEY-----" private const val PKCS_8_PEM_HEADER = "-----BEGIN PRIVATE KEY-----" private const val PKCS_8_PEM_FOOTER = "-----END PRIVATE KEY-----" -fun getX509Cert(certificate: String): X509Certificate { - val targetStream: InputStream = ByteArrayInputStream(certificate.toByteArray()) - return CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate -} - -fun getPrivateKey(privateKey: String): PrivateKey { - return loadKey(privateKey)!! -} - fun loadKey(keyDataString: String): PrivateKey? { if (keyDataString.contains(PKCS_1_PEM_HEADER)) { // OpenSSL / PKCS#1 Base64 PEM encoded file diff --git a/src/main/kotlin/mdnet/base/Netty.kt b/src/main/kotlin/mdnet/base/Netty.kt index 7cfdd42..11596e4 100644 --- a/src/main/kotlin/mdnet/base/Netty.kt +++ b/src/main/kotlin/mdnet/base/Netty.kt @@ -1,8 +1,10 @@ package mdnet.base import io.netty.bootstrap.ServerBootstrap +import io.netty.buffer.Unpooled import io.netty.channel.ChannelFactory import io.netty.channel.ChannelFuture +import io.netty.channel.ChannelHandler import io.netty.channel.ChannelHandlerContext import io.netty.channel.ChannelInboundHandlerAdapter import io.netty.channel.ChannelInitializer @@ -12,10 +14,16 @@ import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.handler.codec.DecoderException +import io.netty.handler.codec.http.DefaultFullHttpResponse +import io.netty.handler.codec.http.HttpHeaderNames import io.netty.handler.codec.http.HttpObjectAggregator +import io.netty.handler.codec.http.HttpResponseStatus import io.netty.handler.codec.http.HttpServerCodec +import io.netty.handler.codec.http.HttpUtil +import io.netty.handler.codec.http.HttpVersion import io.netty.handler.ssl.OptionalSslHandler import io.netty.handler.ssl.SslContextBuilder +import io.netty.handler.ssl.SslHandler import io.netty.handler.stream.ChunkedWriteHandler import io.netty.handler.traffic.GlobalTrafficShapingHandler import io.netty.handler.traffic.TrafficCounter @@ -24,14 +32,58 @@ import org.http4k.server.Http4kChannelHandler import org.http4k.server.Http4kServer import org.http4k.server.ServerConfig import org.slf4j.LoggerFactory +import java.io.ByteArrayInputStream import java.io.IOException +import java.io.InputStream import java.net.InetSocketAddress +import java.nio.charset.StandardCharsets +import java.security.PrivateKey +import java.security.cert.CertificateFactory +import java.security.cert.X509Certificate import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.atomic.AtomicReference import javax.net.ssl.SSLException private val LOGGER = LoggerFactory.getLogger("Application") +@ChannelHandler.Sharable +class ConnectionCounter : ChannelInboundHandlerAdapter() { + private val connections = AtomicInteger() + + override fun channelActive(ctx: ChannelHandlerContext) { + val sslHandler = ctx.pipeline()[SslHandler::class.java] + + if (sslHandler != null) { + sslHandler.handshakeFuture().addListener { + handleConnection(ctx) + } + } else { + handleConnection(ctx) + } + } + + private fun handleConnection(ctx: ChannelHandlerContext) { + if (connections.incrementAndGet() <= Constants.MAX_CONCURRENT_CONNECTIONS) { + super.channelActive(ctx) + } else { + val response = Unpooled.copiedBuffer(Constants.OVERLOADED_MESSAGE, StandardCharsets.UTF_8) + val res = + DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SERVICE_UNAVAILABLE, response) + res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html; charset=UTF-8") + HttpUtil.setContentLength(res, response.readableBytes().toLong()) + + ctx.writeAndFlush(res) + ctx.close() + } + } + + override fun channelInactive(ctx: ChannelHandlerContext) { + super.channelInactive(ctx) + connections.decrementAndGet() + } +} + class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: ClientSettings, private val stats: AtomicReference) : ServerConfig { override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer { private val masterGroup = NioEventLoopGroup() @@ -46,9 +98,14 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: counter.resetCumulativeTime() } } + private val limiter = ConnectionCounter(); override fun start(): Http4kServer = apply { - val sslContext = SslContextBuilder.forServer(getPrivateKey(tls.privateKey), getX509Cert(tls.certificate)).build() + val (mainCert, chainCert) = getX509Certs(tls.certificate); + val sslContext = SslContextBuilder + .forServer(getPrivateKey(tls.privateKey), mainCert, chainCert) + .protocols("TLSv1.3", "TLSv.1.2", "TLSv.1.1", "TLSv.1.0") + .build() val bootstrap = ServerBootstrap() bootstrap.group(masterGroup, workerGroup) @@ -57,6 +114,7 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: public override fun initChannel(ch: SocketChannel) { ch.pipeline().addLast("ssl", OptionalSslHandler(sslContext)) + ch.pipeline().addLast("limiter", limiter) ch.pipeline().addLast("codec", HttpServerCodec()) ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536)) @@ -98,3 +156,12 @@ class Netty(private val tls: ServerSettings.TlsCert, private val clientSettings: override fun port(): Int = if (clientSettings.clientPort > 0) clientSettings.clientPort else address.port } } + +fun getX509Certs(certificates: String): Pair { + val targetStream: InputStream = ByteArrayInputStream(certificates.toByteArray()) + return (CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate) to (CertificateFactory.getInstance("X509").generateCertificate(targetStream) as X509Certificate) +} + +fun getPrivateKey(privateKey: String): PrivateKey { + return loadKey(privateKey)!! +} \ No newline at end of file