import (
"crypto/tls"
+ "crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log"
)
-var HostToCertificate map[string]*tls.Certificate
+var (
+ NextProtos = []string{"h2", "http/1.1"}
+ HostToCertificate map[string]*tls.Certificate
+ HostClientAuth map[string]*x509.CertPool
+)
func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
cert := HostToCertificate[chi.ServerName]
return cert, nil
}
+func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) {
+ pool := HostClientAuth[chi.ServerName]
+ if pool == nil {
+ return nil, nil
+ }
+ return &tls.Config{
+ GetCertificate: GetCertificate,
+ NextProtos: NextProtos,
+ ClientCAs: pool,
+ ClientAuth: tls.RequireAndVerifyClientCert,
+ }, nil
+}
+
func LoadCertificates() {
HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
+ HostClientAuth = make(map[string]*x509.CertPool)
for host, cfg := range Hosts {
if cfg.TLS == nil {
continue
cert.Certificate = append(cert.Certificate, block.Bytes)
}
HostToCertificate[host] = &cert
+ pool := x509.NewCertPool()
+ for _, p := range cfg.TLS.ClientCAs {
+ data, err := ioutil.ReadFile(p)
+ if err != nil {
+ log.Fatalln(err)
+ }
+ var block *pem.Block
+ for len(data) > 0 {
+ block, data = pem.Decode(data)
+ if block == nil {
+ log.Fatalln("can not decode PEM:", p)
+ }
+ if block.Type != "CERTIFICATE" {
+ continue
+ }
+ ca, err := x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ log.Fatalln(err)
+ }
+ pool.AddCert(ca)
+ }
+ }
+ if len(pool.Subjects()) > 0 {
+ HostClientAuth[host] = pool
+ }
+ }
+}
+
+func NewTLSConfig() *tls.Config {
+ return &tls.Config{
+ NextProtos: NextProtos,
+ GetCertificate: GetCertificate,
+ GetConfigForClient: GetConfigForClient,
}
}