]> Sergey Matveev's repositories - godlighty.git/blobdiff - tls.go
Simultaneous ECDSA and EdDSA support
[godlighty.git] / tls.go
diff --git a/tls.go b/tls.go
index 40395f0c590de89b0b4b516a70ba512897c4c4f5..51470f37bb802ebe1f45cb10eefe0c1b2d24e74b 100644 (file)
--- a/tls.go
+++ b/tls.go
@@ -30,20 +30,49 @@ import (
 var (
        NextProtos = []string{"h2", "http/1.1"}
 
-       HostToCertificate map[string]*tls.Certificate
-       HostClientAuth    map[string]*x509.CertPool
+       HostToECDSACertificate map[string]*tls.Certificate
+       HostECDSAClientAuth    map[string]*x509.CertPool
+
+       HostToEdDSACertificate map[string]*tls.Certificate
+       HostEdDSAClientAuth    map[string]*x509.CertPool
 
        HostToGOSTCertificate map[string]*tls.Certificate
        HostGOSTClientAuth    map[string]*x509.CertPool
 )
 
+func CHIHasTLS13(chi *tls.ClientHelloInfo) bool {
+       for _, v := range chi.SupportedVersions {
+               if v == tls.VersionTLS13 {
+                       return true
+               }
+       }
+       return false
+}
+
+func CHIHasEdDSA(chi *tls.ClientHelloInfo) bool {
+       if !CHIHasTLS13(chi) {
+               return false
+       }
+       for _, ss := range chi.SignatureSchemes {
+               if ss == tls.Ed25519 {
+                       return true
+               }
+       }
+       return false
+}
+
 func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
        if CHIHasGOST(chi) {
                if cert := HostToGOSTCertificate[chi.ServerName]; cert != nil {
                        return cert, nil
                }
        }
-       cert := HostToCertificate[chi.ServerName]
+       if CHIHasEdDSA(chi) {
+               if cert := HostToEdDSACertificate[chi.ServerName]; cert != nil {
+                       return cert, nil
+               }
+       }
+       cert := HostToECDSACertificate[chi.ServerName]
        if cert == nil {
                return nil, errors.New("no certificate found")
        }
@@ -55,8 +84,11 @@ func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) {
        if CHIHasGOST(chi) {
                pool = HostGOSTClientAuth[chi.ServerName]
        }
+       if pool == nil && CHIHasEdDSA(chi) {
+               pool = HostEdDSAClientAuth[chi.ServerName]
+       }
        if pool == nil {
-               pool = HostClientAuth[chi.ServerName]
+               pool = HostECDSAClientAuth[chi.ServerName]
        }
        if pool == nil {
                return nil, nil
@@ -125,12 +157,15 @@ func loadCertificates(
 }
 
 func LoadCertificates() {
-       HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
-       HostClientAuth = make(map[string]*x509.CertPool)
+       HostToECDSACertificate = make(map[string]*tls.Certificate, len(Hosts))
+       HostECDSAClientAuth = make(map[string]*x509.CertPool)
+       HostToEdDSACertificate = make(map[string]*tls.Certificate, len(Hosts))
+       HostEdDSAClientAuth = make(map[string]*x509.CertPool)
        HostToGOSTCertificate = make(map[string]*tls.Certificate, len(Hosts))
        HostGOSTClientAuth = make(map[string]*x509.CertPool)
        for host, cfg := range Hosts {
-               loadCertificates(host, cfg.TLS, &HostToCertificate, &HostClientAuth)
+               loadCertificates(host, cfg.ECDSATLS, &HostToECDSACertificate, &HostECDSAClientAuth)
+               loadCertificates(host, cfg.EdDSATLS, &HostToEdDSACertificate, &HostEdDSAClientAuth)
                loadCertificates(host, cfg.GOSTTLS, &HostToGOSTCertificate, &HostGOSTClientAuth)
        }
 }