]> Sergey Matveev's repositories - godlighty.git/blobdiff - tls.go
Authentication and authorization
[godlighty.git] / tls.go
diff --git a/tls.go b/tls.go
index 9ac600c364e25a4ed1a9f853d784d3fe9f9a28c4..88f1dc38ff2d458477a77f068c8655aaadb38d7c 100644 (file)
--- a/tls.go
+++ b/tls.go
@@ -19,6 +19,7 @@ package godlighty
 
 import (
        "crypto/tls"
+       "crypto/x509"
        "encoding/pem"
        "errors"
        "fmt"
@@ -26,7 +27,11 @@ import (
        "log"
 )
 
-var HostToCertificate map[string]*tls.Certificate
+var (
+       NextProtos        = []string{"h2", "http/1.1"}
+       HostToCertificate map[string]*tls.Certificate
+       HostClientAuth    map[string]*x509.CertPool
+)
 
 func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
        cert := HostToCertificate[chi.ServerName]
@@ -36,8 +41,22 @@ func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
        return cert, nil
 }
 
+func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) {
+       pool := HostClientAuth[chi.ServerName]
+       if pool == nil {
+               return nil, nil
+       }
+       return &tls.Config{
+               GetCertificate: GetCertificate,
+               NextProtos:     NextProtos,
+               ClientCAs:      pool,
+               ClientAuth:     tls.RequireAndVerifyClientCert,
+       }, nil
+}
+
 func LoadCertificates() {
        HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
+       HostClientAuth = make(map[string]*x509.CertPool)
        for host, cfg := range Hosts {
                if cfg.TLS == nil {
                        continue
@@ -61,5 +80,38 @@ func LoadCertificates() {
                        cert.Certificate = append(cert.Certificate, block.Bytes)
                }
                HostToCertificate[host] = &cert
+               pool := x509.NewCertPool()
+               for _, p := range cfg.TLS.ClientCAs {
+                       data, err := ioutil.ReadFile(p)
+                       if err != nil {
+                               log.Fatalln(err)
+                       }
+                       var block *pem.Block
+                       for len(data) > 0 {
+                               block, data = pem.Decode(data)
+                               if block == nil {
+                                       log.Fatalln("can not decode PEM:", p)
+                               }
+                               if block.Type != "CERTIFICATE" {
+                                       continue
+                               }
+                               ca, err := x509.ParseCertificate(block.Bytes)
+                               if err != nil {
+                                       log.Fatalln(err)
+                               }
+                               pool.AddCert(ca)
+                       }
+               }
+               if len(pool.Subjects()) > 0 {
+                       HostClientAuth[host] = pool
+               }
+       }
+}
+
+func NewTLSConfig() *tls.Config {
+       return &tls.Config{
+               NextProtos:         NextProtos,
+               GetCertificate:     GetCertificate,
+               GetConfigForClient: GetConfigForClient,
        }
 }