X-Git-Url: http://www.git.stargrave.org/?a=blobdiff_plain;f=tls.go;h=e3f62a9d74ebfb5cad8540cfbf993fcdae643a85;hb=16a92e8c1ea2a890d841019761be5c9f6b334f7a;hp=88f1dc38ff2d458477a77f068c8655aaadb38d7c;hpb=241d153049750166c970dc644e1cf14b8b3f3509;p=godlighty.git diff --git a/tls.go b/tls.go index 88f1dc3..e3f62a9 100644 --- a/tls.go +++ b/tls.go @@ -28,12 +28,21 @@ import ( ) var ( - NextProtos = []string{"h2", "http/1.1"} + NextProtos = []string{"h2", "http/1.1"} + HostToCertificate map[string]*tls.Certificate HostClientAuth map[string]*x509.CertPool + + HostToGOSTCertificate map[string]*tls.Certificate + HostGOSTClientAuth map[string]*x509.CertPool ) func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + if CHIHasGOST(chi) { + if cert := HostToGOSTCertificate[chi.ServerName]; cert != nil { + return cert, nil + } + } cert := HostToCertificate[chi.ServerName] if cert == nil { return nil, errors.New("no certificate found") @@ -42,7 +51,13 @@ func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { } func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) { - pool := HostClientAuth[chi.ServerName] + var pool *x509.CertPool + if CHIHasGOST(chi) { + pool = HostGOSTClientAuth[chi.ServerName] + } + if pool == nil { + pool = HostClientAuth[chi.ServerName] + } if pool == nil { return nil, nil } @@ -54,58 +69,70 @@ func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) { }, nil } -func LoadCertificates() { - HostToCertificate = make(map[string]*tls.Certificate, len(Hosts)) - HostClientAuth = make(map[string]*x509.CertPool) - for host, cfg := range Hosts { - if cfg.TLS == nil { - continue +func loadCertificates( + host string, + cfg *TLSCfg, + hostToCertificate *map[string]*tls.Certificate, + hostClientAuth *map[string]*x509.CertPool, +) { + if cfg == nil { + return + } + cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key) + if err != nil { + log.Fatalln(err) + } + if cfg.CACert != "" { + data, err := ioutil.ReadFile(cfg.CACert) + if err != nil { + log.Fatalln(err) + } + block, _ := pem.Decode(data) + if block == nil { + log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.CACert)) } - cert, err := tls.LoadX509KeyPair(cfg.TLS.Cert, cfg.TLS.Key) + if block.Type != "CERTIFICATE" { + log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.CACert)) + } + cert.Certificate = append(cert.Certificate, block.Bytes) + } + (*hostToCertificate)[host] = &cert + pool := x509.NewCertPool() + for _, p := range cfg.ClientCAs { + data, err := ioutil.ReadFile(p) if err != nil { log.Fatalln(err) } - if cfg.TLS.CACert != "" { - data, err := ioutil.ReadFile(cfg.TLS.CACert) - if err != nil { - log.Fatalln(err) - } - block, _ := pem.Decode(data) + var block *pem.Block + for len(data) > 0 { + block, data = pem.Decode(data) if block == nil { - log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.TLS.CACert)) + log.Fatalln("can not decode PEM:", p) } if block.Type != "CERTIFICATE" { - log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.TLS.CACert)) + continue } - cert.Certificate = append(cert.Certificate, block.Bytes) - } - HostToCertificate[host] = &cert - pool := x509.NewCertPool() - for _, p := range cfg.TLS.ClientCAs { - data, err := ioutil.ReadFile(p) + ca, err := x509.ParseCertificate(block.Bytes) if err != nil { log.Fatalln(err) } - var block *pem.Block - for len(data) > 0 { - block, data = pem.Decode(data) - if block == nil { - log.Fatalln("can not decode PEM:", p) - } - if block.Type != "CERTIFICATE" { - continue - } - ca, err := x509.ParseCertificate(block.Bytes) - if err != nil { - log.Fatalln(err) - } - pool.AddCert(ca) - } - } - if len(pool.Subjects()) > 0 { - HostClientAuth[host] = pool + pool.AddCert(ca) } } + if len(pool.Subjects()) > 0 { + (*hostClientAuth)[host] = pool + } +} + +func LoadCertificates() { + HostToCertificate = make(map[string]*tls.Certificate, len(Hosts)) + HostClientAuth = make(map[string]*x509.CertPool) + HostToGOSTCertificate = make(map[string]*tls.Certificate, len(Hosts)) + HostGOSTClientAuth = make(map[string]*x509.CertPool) + for host, cfg := range Hosts { + loadCertificates(host, cfg.TLS, &HostToCertificate, &HostClientAuth) + loadCertificates(host, cfg.GOSTTLS, &HostToGOSTCertificate, &HostGOSTClientAuth) + } } func NewTLSConfig() *tls.Config {