From 6acb43bbbb9b15c938d246837a17dca66393409e Mon Sep 17 00:00:00 2001 From: James Hedley Date: Fri, 13 Oct 2023 17:02:24 +0100 Subject: [PATCH] Add optional flags to enable mTLS with verification of client certificate (#4910) --- weed/command/s3.go | 28 +++++++++++++++++++++++++++- weed/command/server.go | 2 ++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/weed/command/s3.go b/weed/command/s3.go index f2dffd57e..dc943b23d 100644 --- a/weed/command/s3.go +++ b/weed/command/s3.go @@ -3,7 +3,9 @@ package command import ( "context" "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "net" "net/http" "os" @@ -42,6 +44,8 @@ type S3Options struct { domainName *string tlsPrivateKey *string tlsCertificate *string + tlsCACertificate *string + tlsVerifyClientCert *bool metricsHttpPort *int allowEmptyFolder *bool allowDeleteBucketNotEmpty *bool @@ -65,6 +69,8 @@ func init() { s3StandaloneOptions.auditLogConfig = cmdS3.Flag.String("auditLogConfig", "", "path to the audit log config file") s3StandaloneOptions.tlsPrivateKey = cmdS3.Flag.String("key.file", "", "path to the TLS private key file") s3StandaloneOptions.tlsCertificate = cmdS3.Flag.String("cert.file", "", "path to the TLS certificate file") + s3StandaloneOptions.tlsCACertificate = cmdS3.Flag.String("cacert.file", "", "path to the TLS CA certificate file") + s3StandaloneOptions.tlsVerifyClientCert = cmdS3.Flag.Bool("tlsVerifyClientCert", false, "whether to verify the client's certificate") s3StandaloneOptions.metricsHttpPort = cmdS3.Flag.Int("metricsPort", 0, "Prometheus metrics listen port") s3StandaloneOptions.allowEmptyFolder = cmdS3.Flag.Bool("allowEmptyFolder", true, "allow empty folders") s3StandaloneOptions.allowDeleteBucketNotEmpty = cmdS3.Flag.Bool("allowDeleteBucketNotEmpty", true, "allow recursive deleting all entries along with bucket") @@ -289,7 +295,27 @@ func (s3opt *S3Options) startS3Server() bool { if s3opt.certProvider, err = pemfile.NewProvider(pemfileOptions); err != nil { glog.Fatalf("pemfile.NewProvider(%v) failed: %v", pemfileOptions, err) } - httpS.TLSConfig = &tls.Config{GetCertificate: s3opt.GetCertificateWithUpdate} + + caCertPool := x509.NewCertPool() + if *s3Options.tlsCACertificate != "" { + // load CA certificate file and add it to list of client CAs + caCertFile, err := ioutil.ReadFile(*s3opt.tlsCACertificate) + if err != nil { + glog.Fatalf("error reading CA certificate: %v", err) + } + caCertPool.AppendCertsFromPEM(caCertFile) + } + + clientAuth := tls.NoClientCert + if *s3Options.tlsVerifyClientCert { + clientAuth = tls.RequireAndVerifyClientCert + } + + httpS.TLSConfig = &tls.Config{ + GetCertificate: s3opt.GetCertificateWithUpdate, + ClientAuth: clientAuth, + ClientCAs: caCertPool, + } if *s3opt.portHttps == 0 { glog.V(0).Infof("Start Seaweed S3 API Server %s at https port %d", util.Version(), *s3opt.port) if s3ApiLocalListener != nil { diff --git a/weed/command/server.go b/weed/command/server.go index 7fbb59676..67e37426e 100644 --- a/weed/command/server.go +++ b/weed/command/server.go @@ -144,6 +144,8 @@ func init() { s3Options.domainName = cmdServer.Flag.String("s3.domainName", "", "suffix of the host name in comma separated list, {bucket}.{domainName}") s3Options.tlsPrivateKey = cmdServer.Flag.String("s3.key.file", "", "path to the TLS private key file") s3Options.tlsCertificate = cmdServer.Flag.String("s3.cert.file", "", "path to the TLS certificate file") + s3Options.tlsCACertificate = cmdServer.Flag.String("s3.cacert.file", "", "path to the TLS CA certificate file") + s3Options.tlsVerifyClientCert = cmdServer.Flag.Bool("s3.tlsVerifyClientCert", false, "whether to verify the client's certificate") s3Options.config = cmdServer.Flag.String("s3.config", "", "path to the config file") s3Options.auditLogConfig = cmdServer.Flag.String("s3.auditLogConfig", "", "path to the audit log config file") s3Options.allowEmptyFolder = cmdServer.Flag.Bool("s3.allowEmptyFolder", true, "allow empty folders")