1
0
Fork 1
mirror of https://gitlab.com/mangadex-pub/mangadex_at_home.git synced 2024-01-19 02:48:37 +00:00

Several microopts

This commit is contained in:
carbotaniuman 2021-01-26 10:15:50 -06:00
parent f553bddff6
commit d410e5ec17
9 changed files with 71 additions and 70 deletions

View file

@ -15,7 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed ### Fixed
- [2021-01-25] Add `privileged: true` to mangadex-at-home service in docker-compose to enable use of IOUring for the dockerized version [@_tde9]. - [2021-01-25] Add `privileged: true` to mangadex-at-home service in docker-compose to enable use of IOUring for the dockerized version [@_tde9].
- [2021-01-26] Make updated config restart the webserver and apply changes[@carbotaniuman]. - [2021-01-26] Make updated config restart the webserver and apply changes [@carbotaniuman].
- [2021-01-26] Optimize some code to reduce allocations [@carbotaniuman].
### Security ### Security

View file

@ -196,9 +196,6 @@ class MangaDexClient(private val settingsFile: File, databaseFile: File, cacheFo
if (it.port in Constants.RESTRICTED_PORTS) { if (it.port in Constants.RESTRICTED_PORTS) {
throw ClientSettingsException("Config Error: Unsafe port number") throw ClientSettingsException("Config Error: Unsafe port number")
} }
if (it.threads < 4) {
throw ClientSettingsException("Config Error: Invalid number of threads, must be >= 4")
}
if (it.maxMebibytesPerHour < 0) { if (it.maxMebibytesPerHour < 0) {
throw ClientSettingsException("Config Error: Max bandwidth must be >= 0") throw ClientSettingsException("Config Error: Max bandwidth must be >= 0")
} }

View file

@ -33,7 +33,6 @@ import org.slf4j.LoggerFactory
import java.util.concurrent.CountDownLatch import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
sealed class State sealed class State
@ -61,9 +60,7 @@ class ServerManager(
// this must remain single-threaded because of how the state mechanism works // this must remain single-threaded because of how the state mechanism works
private val executor = Executors.newSingleThreadScheduledExecutor() private val executor = Executors.newSingleThreadScheduledExecutor()
private val registry = PrometheusMeterRegistry(PrometheusConfig.DEFAULT) private val registry = PrometheusMeterRegistry(PrometheusConfig.DEFAULT)
private val statistics: AtomicReference<Statistics> = AtomicReference( private val statistics = Statistics()
Statistics()
)
// state that must only be accessed from the thread on the executor // state that must only be accessed from the thread on the executor
private var state: State private var state: State
@ -80,11 +77,11 @@ class ServerManager(
DefaultMicrometerMetrics(registry, storage.cacheDirectory) DefaultMicrometerMetrics(registry, storage.cacheDirectory)
loginAndStartServer() loginAndStartServer()
var lastBytesSent = statistics.get().bytesSent var lastBytesSent = statistics.bytesSent.get()
executor.scheduleAtFixedRate( executor.scheduleAtFixedRate(
{ {
try { try {
lastBytesSent = statistics.get().bytesSent lastBytesSent = statistics.bytesSent.get()
val state = this.state val state = this.state
if (state is GracefulStop && state.nextState != Shutdown) { if (state is GracefulStop && state.nextState != Shutdown) {
@ -157,7 +154,7 @@ class ServerManager(
try { try {
val state = this.state val state = this.state
if (state is Running) { if (state is Running) {
val currentBytesSent = statistics.get().bytesSent - lastBytesSent val currentBytesSent = statistics.bytesSent.get() - lastBytesSent
if (settings.serverSettings.maxMebibytesPerHour != 0L && settings.serverSettings.maxMebibytesPerHour * 1024 * 1024 /* MiB to bytes */ < currentBytesSent) { if (settings.serverSettings.maxMebibytesPerHour != 0L && settings.serverSettings.maxMebibytesPerHour * 1024 * 1024 /* MiB to bytes */ < currentBytesSent) {
LOGGER.info { "Stopping image server as hourly bandwidth limit reached" } LOGGER.info { "Stopping image server as hourly bandwidth limit reached" }
@ -220,8 +217,8 @@ class ServerManager(
storage, storage,
remoteSettings, remoteSettings,
settings.serverSettings, settings.serverSettings,
statistics,
settings.metricsSettings, settings.metricsSettings,
statistics,
registry registry
).start() ).start()

View file

@ -165,6 +165,8 @@ class ImageStorage(
* @return the [Image] associated with the id or null. * @return the [Image] associated with the id or null.
*/ */
fun loadImage(id: String): Image? { fun loadImage(id: String): Image? {
LOGGER.trace { "Loading image $id from cache" }
return try { return try {
// this try catch handles the case where the image has been deleted // this try catch handles the case where the image has been deleted
// we assume total control over the directory, so this file open // we assume total control over the directory, so this file open
@ -197,6 +199,8 @@ class ImageStorage(
* @return the [Writer] associated with the id or null. * @return the [Writer] associated with the id or null.
*/ */
fun storeImage(id: String, metadata: ImageMetadata): Writer? { fun storeImage(id: String, metadata: ImageMetadata): Writer? {
LOGGER.trace { "Storing image $id into cache" }
if (id.length < 3) { if (id.length < 3) {
throw IllegalArgumentException("id length needs to be at least 3") throw IllegalArgumentException("id length needs to be at least 3")
} }
@ -210,6 +214,8 @@ class ImageStorage(
} }
private fun deleteImage(id: String) { private fun deleteImage(id: String) {
LOGGER.trace { "Deleting image $id from cache" }
database.useTransaction { database.useTransaction {
val path = getTempPath() val path = getTempPath()
@ -222,7 +228,6 @@ class ImageStorage(
Files.deleteIfExists(path) Files.deleteIfExists(path)
} catch (e: IOException) { } catch (e: IOException) {
LOGGER.trace(e) { "Deleting image failed, ignoring" }
// a failure means the image did not exist // a failure means the image did not exist
} finally { } finally {
database.delete(DbImage) { database.delete(DbImage) {
@ -354,11 +359,10 @@ class ImageStorage(
companion object { companion object {
private val LOGGER = LoggerFactory.getLogger(ImageStorage::class.java) private val LOGGER = LoggerFactory.getLogger(ImageStorage::class.java)
private val JACKSON: ObjectMapper = jacksonObjectMapper()
private fun String.toCachePath() = private fun String.toCachePath() =
this.substring(0, 3).replace(".(?!$)".toRegex(), "$0 ").split(" ".toRegex()).reversed() this.substring(0, 3).replace(".(?!$)".toRegex(), "$0 ").split(" ".toRegex()).reversed()
.plus(this).joinToString(File.separator) .plus(this).joinToString(File.separator)
private val JACKSON: ObjectMapper = jacksonObjectMapper()
} }
} }

View file

@ -18,10 +18,8 @@ along with this MangaDex@Home. If not, see <http://www.gnu.org/licenses/>.
*/ */
package mdnet.data package mdnet.data
import com.fasterxml.jackson.databind.PropertyNamingStrategies import java.util.concurrent.atomic.AtomicLong
import com.fasterxml.jackson.databind.annotation.JsonNaming
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy::class) class Statistics(
data class Statistics( val bytesSent: AtomicLong = AtomicLong(0),
val bytesSent: Long = 0,
) )

View file

@ -27,7 +27,9 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException import io.netty.handler.codec.DecoderException
import io.netty.handler.codec.http.* import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpServerKeepAliveHandler
import io.netty.handler.ssl.SslContextBuilder import io.netty.handler.ssl.SslContextBuilder
import io.netty.handler.stream.ChunkedWriteHandler import io.netty.handler.stream.ChunkedWriteHandler
import io.netty.handler.timeout.ReadTimeoutException import io.netty.handler.timeout.ReadTimeoutException
@ -39,11 +41,12 @@ import io.netty.handler.traffic.TrafficCounter
import io.netty.incubator.channel.uring.IOUring import io.netty.incubator.channel.uring.IOUring
import io.netty.incubator.channel.uring.IOUringEventLoopGroup import io.netty.incubator.channel.uring.IOUringEventLoopGroup
import io.netty.incubator.channel.uring.IOUringServerSocketChannel import io.netty.incubator.channel.uring.IOUringServerSocketChannel
import io.netty.util.concurrent.DefaultEventExecutorGroup import io.netty.util.internal.SystemPropertyUtil
import mdnet.Constants import mdnet.Constants
import mdnet.data.Statistics import mdnet.data.Statistics
import mdnet.logging.info import mdnet.logging.info
import mdnet.logging.trace import mdnet.logging.trace
import mdnet.logging.warn
import mdnet.settings.ServerSettings import mdnet.settings.ServerSettings
import mdnet.settings.TlsCert import mdnet.settings.TlsCert
import org.http4k.core.HttpHandler import org.http4k.core.HttpHandler
@ -59,33 +62,33 @@ import java.net.SocketException
import java.security.PrivateKey import java.security.PrivateKey
import java.security.cert.CertificateFactory import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate import java.security.cert.X509Certificate
import java.util.concurrent.atomic.AtomicReference import java.util.Locale
import javax.net.ssl.SSLException import javax.net.ssl.SSLException
interface NettyTransport { interface NettyTransport {
val masterGroup: EventLoopGroup val bossGroup: EventLoopGroup
val workerGroup: EventLoopGroup val workerGroup: EventLoopGroup
val factory: ChannelFactory<ServerChannel> val factory: ChannelFactory<ServerChannel>
fun shutdownGracefully() { fun shutdownGracefully() {
masterGroup.shutdownGracefully() bossGroup.shutdownGracefully()
workerGroup.shutdownGracefully() workerGroup.shutdownGracefully()
} }
private class NioTransport : NettyTransport { private class NioTransport : NettyTransport {
override val masterGroup = NioEventLoopGroup() override val bossGroup = NioEventLoopGroup(1)
override val workerGroup = NioEventLoopGroup() override val workerGroup = NioEventLoopGroup()
override val factory = ChannelFactory<ServerChannel> { NioServerSocketChannel() } override val factory = ChannelFactory<ServerChannel> { NioServerSocketChannel() }
} }
private class EpollTransport : NettyTransport { private class EpollTransport : NettyTransport {
override val masterGroup = EpollEventLoopGroup() override val bossGroup = EpollEventLoopGroup(1)
override val workerGroup = EpollEventLoopGroup() override val workerGroup = EpollEventLoopGroup()
override val factory = ChannelFactory<ServerChannel> { EpollServerSocketChannel() } override val factory = ChannelFactory<ServerChannel> { EpollServerSocketChannel() }
} }
private class IOUringTransport : NettyTransport { private class IOUringTransport : NettyTransport {
override val masterGroup = IOUringEventLoopGroup() override val bossGroup = IOUringEventLoopGroup(1)
override val workerGroup = IOUringEventLoopGroup() override val workerGroup = IOUringEventLoopGroup()
override val factory = ChannelFactory<ServerChannel> { IOUringServerSocketChannel() } override val factory = ChannelFactory<ServerChannel> { IOUringServerSocketChannel() }
} }
@ -94,21 +97,24 @@ interface NettyTransport {
private val LOGGER = LoggerFactory.getLogger(NettyTransport::class.java) private val LOGGER = LoggerFactory.getLogger(NettyTransport::class.java)
fun bestForPlatform(): NettyTransport { fun bestForPlatform(): NettyTransport {
if (IOUring.isAvailable()) { val name = SystemPropertyUtil.get("os.name").toLowerCase(Locale.UK).trim { it <= ' ' }
LOGGER.info("Using IOUring transport") if (name.startsWith("linux")) {
return IOUringTransport() if (IOUring.isAvailable()) {
} else { LOGGER.info("Using IOUring transport")
LOGGER.info(IOUring.unavailabilityCause()) { return IOUringTransport()
"IOUring transport not available" } else {
LOGGER.info(IOUring.unavailabilityCause()) {
"IOUring transport not available"
}
} }
}
if (Epoll.isAvailable()) { if (Epoll.isAvailable()) {
LOGGER.info("Using Epoll transport") LOGGER.info("Using Epoll transport")
return EpollTransport() return EpollTransport()
} else { } else {
LOGGER.info(Epoll.unavailabilityCause()) { LOGGER.info(Epoll.unavailabilityCause()) {
"Epoll transport not available" "Epoll transport not available"
}
} }
} }
@ -118,10 +124,13 @@ interface NettyTransport {
} }
} }
class Netty(private val tls: TlsCert, private val serverSettings: ServerSettings, private val statistics: AtomicReference<Statistics>) : ServerConfig { class Netty(
private val tls: TlsCert,
private val serverSettings: ServerSettings,
private val statistics: Statistics
) : ServerConfig {
override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer { override fun toServer(httpHandler: HttpHandler): Http4kServer = object : Http4kServer {
private val transport = NettyTransport.bestForPlatform() private val transport = NettyTransport.bestForPlatform()
private val executor = DefaultEventExecutorGroup(serverSettings.threads)
private lateinit var closeFuture: ChannelFuture private lateinit var closeFuture: ChannelFuture
private lateinit var address: InetSocketAddress private lateinit var address: InetSocketAddress
@ -130,15 +139,13 @@ class Netty(private val tls: TlsCert, private val serverSettings: ServerSettings
transport.workerGroup, serverSettings.maxKilobitsPerSecond * 1000L / 8L, 0, 50 transport.workerGroup, serverSettings.maxKilobitsPerSecond * 1000L / 8L, 0, 50
) { ) {
override fun doAccounting(counter: TrafficCounter) { override fun doAccounting(counter: TrafficCounter) {
statistics.getAndUpdate { statistics.bytesSent.getAndAccumulate(counter.cumulativeWrittenBytes()) { a, b -> a + b }
it.copy(bytesSent = it.bytesSent + counter.cumulativeWrittenBytes())
}
counter.resetCumulativeTime() counter.resetCumulativeTime()
} }
} }
override fun start(): Http4kServer = apply { override fun start(): Http4kServer = apply {
LOGGER.info { "Starting Netty with ${serverSettings.threads} threads" } LOGGER.info { "Starting Netty!" }
val certs = getX509Certs(tls.certificate) val certs = getX509Certs(tls.certificate)
val sslContext = SslContextBuilder val sslContext = SslContextBuilder
@ -147,7 +154,7 @@ class Netty(private val tls: TlsCert, private val serverSettings: ServerSettings
.build() .build()
val bootstrap = ServerBootstrap() val bootstrap = ServerBootstrap()
bootstrap.group(transport.masterGroup, transport.workerGroup) bootstrap.group(transport.bossGroup, transport.workerGroup)
.channelFactory(transport.factory) .channelFactory(transport.factory)
.childHandler(object : ChannelInitializer<SocketChannel>() { .childHandler(object : ChannelInitializer<SocketChannel>() {
public override fun initChannel(ch: SocketChannel) { public override fun initChannel(ch: SocketChannel) {
@ -159,24 +166,28 @@ class Netty(private val tls: TlsCert, private val serverSettings: ServerSettings
ch.pipeline().addLast("burstLimiter", burstLimiter) ch.pipeline().addLast("burstLimiter", burstLimiter)
ch.pipeline().addLast("readTimeoutHandler", ReadTimeoutHandler(Constants.MAX_READ_TIME_SECONDS)) ch.pipeline().addLast(
ch.pipeline().addLast("writeTimeoutHandler", WriteTimeoutHandler(Constants.MAX_WRITE_TIME_SECONDS)) "readTimeoutHandler",
ReadTimeoutHandler(Constants.MAX_READ_TIME_SECONDS)
)
ch.pipeline().addLast(
"writeTimeoutHandler",
WriteTimeoutHandler(Constants.MAX_WRITE_TIME_SECONDS)
)
ch.pipeline().addLast("streamer", ChunkedWriteHandler()) ch.pipeline().addLast("streamer", ChunkedWriteHandler())
ch.pipeline().addLast(executor, "handler", Http4kChannelHandler(httpHandler)) ch.pipeline().addLast("handler", Http4kChannelHandler(httpHandler))
ch.pipeline().addLast( ch.pipeline().addLast(
"exceptions", "exceptions",
object : ChannelInboundHandlerAdapter() { object : ChannelInboundHandlerAdapter() {
override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) { override fun exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable) {
if (cause is SSLException || (cause is DecoderException && cause.cause is SSLException)) { if (cause is SSLException || (cause is DecoderException && cause.cause is SSLException)) {
LOGGER.trace { "Ignored invalid SSL connection" } LOGGER.trace(cause) { "Ignored invalid SSL connection" }
LOGGER.trace(cause) { "Exception in pipeline" }
} else if (cause is IOException || cause is SocketException) { } else if (cause is IOException || cause is SocketException) {
LOGGER.info { "User (downloader) abruptly closed the connection" } LOGGER.trace(cause) { "User (downloader) abruptly closed the connection" }
LOGGER.trace(cause) { "Exception in pipeline" }
} else if (cause !is ReadTimeoutException && cause !is WriteTimeoutException) { } else if (cause !is ReadTimeoutException && cause !is WriteTimeoutException) {
ctx.fireExceptionCaught(cause) LOGGER.warn(cause) { "Exception in pipeline" }
} }
} }
} }
@ -194,10 +205,9 @@ class Netty(private val tls: TlsCert, private val serverSettings: ServerSettings
override fun stop() = apply { override fun stop() = apply {
closeFuture.cancel(false) closeFuture.cancel(false)
transport.shutdownGracefully() transport.shutdownGracefully()
executor.shutdownGracefully()
} }
override fun port(): Int = if (serverSettings.port > 0) serverSettings.port else address.port override fun port(): Int = serverSettings.port
} }
companion object { companion object {

View file

@ -69,7 +69,6 @@ import java.time.Clock
import java.time.OffsetDateTime import java.time.OffsetDateTime
import java.util.* import java.util.*
import java.util.concurrent.Executors import java.util.concurrent.Executors
import java.util.concurrent.atomic.AtomicReference
private val LOGGER = LoggerFactory.getLogger(ImageServer::class.java) private val LOGGER = LoggerFactory.getLogger(ImageServer::class.java)
private val JACKSON: ObjectMapper = jacksonObjectMapper() private val JACKSON: ObjectMapper = jacksonObjectMapper()
@ -125,7 +124,7 @@ class ImageServer(
Response(Status.NOT_MODIFIED) Response(Status.NOT_MODIFIED)
.header("Last-Modified", lastModified) .header("Last-Modified", lastModified)
} else { } else {
LOGGER.info { "Request for $sanitizedUri hit cache" } LOGGER.info { "Request for $sanitizedUri is being served" }
respondWithImage( respondWithImage(
BufferedInputStream(image.stream), BufferedInputStream(image.stream),
@ -136,8 +135,6 @@ class ImageServer(
} }
private fun Request.handleCacheMiss(sanitizedUri: String, imageId: String): Response { private fun Request.handleCacheMiss(sanitizedUri: String, imageId: String): Response {
LOGGER.info { "Request for $sanitizedUri missed cache" }
val mdResponse = client(Request(Method.GET, sanitizedUri)) val mdResponse = client(Request(Method.GET, sanitizedUri))
if (mdResponse.status != Status.OK) { if (mdResponse.status != Status.OK) {
@ -234,8 +231,8 @@ fun getServer(
storage: ImageStorage, storage: ImageStorage,
remoteSettings: RemoteSettings, remoteSettings: RemoteSettings,
serverSettings: ServerSettings, serverSettings: ServerSettings,
statistics: AtomicReference<Statistics>,
metricsSettings: MetricsSettings, metricsSettings: MetricsSettings,
statistics: Statistics,
registry: PrometheusMeterRegistry, registry: PrometheusMeterRegistry,
): Http4kServer { ): Http4kServer {
val apache = ApacheClient( val apache = ApacheClient(
@ -261,8 +258,6 @@ fun getServer(
val client = val client =
ClientFilters.SetBaseUriFrom(remoteSettings.imageServer) ClientFilters.SetBaseUriFrom(remoteSettings.imageServer)
.then(ClientFilters.MicrometerMetrics.RequestCounter(registry))
.then(ClientFilters.MicrometerMetrics.RequestTimer(registry))
.then(apache) .then(apache)
val imageServer = ImageServer( val imageServer = ImageServer(
@ -274,7 +269,7 @@ fun getServer(
FunctionCounter.builder( FunctionCounter.builder(
"client_sent_bytes", "client_sent_bytes",
statistics, statistics,
{ it.get().bytesSent.toDouble() } { it.bytesSent.get().toDouble() }
).register(registry) ).register(registry)
val verifier = tokenVerifier( val verifier = tokenVerifier(

View file

@ -41,7 +41,6 @@ data class ServerSettings(
val externalMaxKilobitsPerSecond: Long = 0, val externalMaxKilobitsPerSecond: Long = 0,
val maxMebibytesPerHour: Long = 0, val maxMebibytesPerHour: Long = 0,
val port: Int = 443, val port: Int = 443,
val threads: Int = 4,
) )
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy::class) @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy::class)

View file

@ -1,7 +1,7 @@
<configuration> <configuration>
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender"> <appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
<filter class="ch.qos.logback.classic.filter.ThresholdFilter"> <filter class="ch.qos.logback.classic.filter.ThresholdFilter">
<level>${file-level:-WARN}</level> <level>${file-level:-${root-level:-WARN}}}</level>
</filter> </filter>
<file>log/latest.log</file> <file>log/latest.log</file>
@ -20,7 +20,7 @@
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<filter class="ch.qos.logback.classic.filter.ThresholdFilter"> <filter class="ch.qos.logback.classic.filter.ThresholdFilter">
<level>${stdout-level:-INFO}</level> <level>${stdout-level:-${root-level:-INFO}}</level>
</filter> </filter>
<encoder> <encoder>
@ -29,7 +29,7 @@
</encoder> </encoder>
</appender> </appender>
<root level="TRACE"> <root level="${root-level:-INFO}">
<appender-ref ref="STDOUT"/> <appender-ref ref="STDOUT"/>
<appender-ref ref="FILE"/> <appender-ref ref="FILE"/>
</root> </root>