]> Sergey Matveev's repositories - godlighty.git/blobdiff - tls.go
.xht
[godlighty.git] / tls.go
diff --git a/tls.go b/tls.go
index 9ac600c364e25a4ed1a9f853d784d3fe9f9a28c4..0451b0328f27b66dd1150155174567ea359d711d 100644 (file)
--- a/tls.go
+++ b/tls.go
@@ -1,6 +1,6 @@
 /*
 godlighty -- highly-customizable HTTP, HTTP/2, HTTPS server
-Copyright (C) 2021 Sergey Matveev <stargrave@stargrave.org>
+Copyright (C) 2021-2023 Sergey Matveev <stargrave@stargrave.org>
 
 This program is free software: you can redistribute it and/or modify
 it under the terms of the GNU General Public License as published by
@@ -19,47 +19,160 @@ package godlighty
 
 import (
        "crypto/tls"
+       "crypto/x509"
        "encoding/pem"
        "errors"
        "fmt"
-       "io/ioutil"
        "log"
+       "os"
 )
 
-var HostToCertificate map[string]*tls.Certificate
+var (
+       NextProtos = []string{"h2", "http/1.1"}
+
+       HostToECDSACertificate map[string]*tls.Certificate
+       HostECDSAClientAuth    map[string]*x509.CertPool
+
+       HostToEdDSACertificate map[string]*tls.Certificate
+       HostEdDSAClientAuth    map[string]*x509.CertPool
+
+       HostToGOSTCertificate map[string]*tls.Certificate
+       HostGOSTClientAuth    map[string]*x509.CertPool
+)
+
+func CHIHasTLS13(chi *tls.ClientHelloInfo) bool {
+       for _, v := range chi.SupportedVersions {
+               if v == tls.VersionTLS13 {
+                       return true
+               }
+       }
+       return false
+}
+
+func CHIHasEdDSA(chi *tls.ClientHelloInfo) bool {
+       if !CHIHasTLS13(chi) {
+               return false
+       }
+       for _, ss := range chi.SignatureSchemes {
+               if ss == tls.Ed25519 {
+                       return true
+               }
+       }
+       return false
+}
 
 func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
-       cert := HostToCertificate[chi.ServerName]
+       if CHIHasGOST(chi) {
+               if cert := HostToGOSTCertificate[chi.ServerName]; cert != nil {
+                       return cert, nil
+               }
+       }
+       if CHIHasEdDSA(chi) {
+               if cert := HostToEdDSACertificate[chi.ServerName]; cert != nil {
+                       return cert, nil
+               }
+       }
+       cert := HostToECDSACertificate[chi.ServerName]
        if cert == nil {
                return nil, errors.New("no certificate found")
        }
        return cert, nil
 }
 
-func LoadCertificates() {
-       HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
-       for host, cfg := range Hosts {
-               if cfg.TLS == nil {
-                       continue
+func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) {
+       var pool *x509.CertPool
+       if CHIHasGOST(chi) {
+               pool = HostGOSTClientAuth[chi.ServerName]
+       }
+       if pool == nil && CHIHasEdDSA(chi) {
+               pool = HostEdDSAClientAuth[chi.ServerName]
+       }
+       if pool == nil {
+               pool = HostECDSAClientAuth[chi.ServerName]
+       }
+       if pool == nil {
+               return nil, nil
+       }
+       return &tls.Config{
+               GetCertificate: GetCertificate,
+               NextProtos:     NextProtos,
+               ClientCAs:      pool,
+               ClientAuth:     tls.RequireAndVerifyClientCert,
+       }, nil
+}
+
+func loadCertificates(
+       host string,
+       cfg *TLSCfg,
+       hostToCertificate *map[string]*tls.Certificate,
+       hostClientAuth *map[string]*x509.CertPool,
+) {
+       if cfg == nil {
+               return
+       }
+       cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
+       if err != nil {
+               log.Fatalln(err)
+       }
+       if cfg.CACert != "" {
+               data, err := os.ReadFile(cfg.CACert)
+               if err != nil {
+                       log.Fatalln(err)
                }
-               cert, err := tls.LoadX509KeyPair(cfg.TLS.Cert, cfg.TLS.Key)
+               block, _ := pem.Decode(data)
+               if block == nil {
+                       log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.CACert))
+               }
+               if block.Type != "CERTIFICATE" {
+                       log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.CACert))
+               }
+               cert.Certificate = append(cert.Certificate, block.Bytes)
+       }
+       (*hostToCertificate)[host] = &cert
+       pool := x509.NewCertPool()
+       for _, p := range cfg.ClientCAs {
+               data, err := os.ReadFile(p)
                if err != nil {
                        log.Fatalln(err)
                }
-               if cfg.TLS.CACert != "" {
-                       data, err := ioutil.ReadFile(cfg.TLS.CACert)
-                       if err != nil {
-                               log.Fatalln(err)
-                       }
-                       block, _ := pem.Decode(data)
+               var block *pem.Block
+               for len(data) > 0 {
+                       block, data = pem.Decode(data)
                        if block == nil {
-                               log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.TLS.CACert))
+                               log.Fatalln("can not decode PEM:", p)
                        }
                        if block.Type != "CERTIFICATE" {
-                               log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.TLS.CACert))
+                               continue
                        }
-                       cert.Certificate = append(cert.Certificate, block.Bytes)
+                       ca, err := x509.ParseCertificate(block.Bytes)
+                       if err != nil {
+                               log.Fatalln(err)
+                       }
+                       pool.AddCert(ca)
+                       (*hostClientAuth)[host] = pool
                }
-               HostToCertificate[host] = &cert
+       }
+}
+
+func LoadCertificates() {
+       HostToECDSACertificate = make(map[string]*tls.Certificate, len(Hosts))
+       HostECDSAClientAuth = make(map[string]*x509.CertPool)
+       HostToEdDSACertificate = make(map[string]*tls.Certificate, len(Hosts))
+       HostEdDSAClientAuth = make(map[string]*x509.CertPool)
+       HostToGOSTCertificate = make(map[string]*tls.Certificate, len(Hosts))
+       HostGOSTClientAuth = make(map[string]*x509.CertPool)
+       for host, cfg := range Hosts {
+               loadCertificates(host, cfg.ECDSATLS, &HostToECDSACertificate, &HostECDSAClientAuth)
+               loadCertificates(host, cfg.EdDSATLS, &HostToEdDSACertificate, &HostEdDSAClientAuth)
+               loadCertificates(host, cfg.GOSTTLS, &HostToGOSTCertificate, &HostGOSTClientAuth)
+       }
+}
+
+func NewTLSConfig() *tls.Config {
+       return &tls.Config{
+               MinVersion:         tls.VersionTLS12,
+               NextProtos:         NextProtos,
+               GetCertificate:     GetCertificate,
+               GetConfigForClient: GetConfigForClient,
        }
 }