import (
"crypto/tls"
+ "crypto/x509"
"encoding/pem"
"errors"
"fmt"
"log"
)
-var HostToCertificate map[string]*tls.Certificate
+var (
+ NextProtos = []string{"h2", "http/1.1"}
+
+ HostToCertificate map[string]*tls.Certificate
+ HostClientAuth map[string]*x509.CertPool
+
+ HostToGOSTCertificate map[string]*tls.Certificate
+ HostGOSTClientAuth map[string]*x509.CertPool
+)
func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
+ if CHIHasGOST(chi) {
+ if cert := HostToGOSTCertificate[chi.ServerName]; cert != nil {
+ return cert, nil
+ }
+ }
cert := HostToCertificate[chi.ServerName]
if cert == nil {
return nil, errors.New("no certificate found")
return cert, nil
}
-func LoadCertificates() {
- HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
- for host, cfg := range Hosts {
- if cfg.TLS == nil {
- continue
+func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) {
+ var pool *x509.CertPool
+ if CHIHasGOST(chi) {
+ pool = HostGOSTClientAuth[chi.ServerName]
+ }
+ if pool == nil {
+ pool = HostClientAuth[chi.ServerName]
+ }
+ if pool == nil {
+ return nil, nil
+ }
+ return &tls.Config{
+ GetCertificate: GetCertificate,
+ NextProtos: NextProtos,
+ ClientCAs: pool,
+ ClientAuth: tls.RequireAndVerifyClientCert,
+ }, nil
+}
+
+func loadCertificates(
+ host string,
+ cfg *TLSCfg,
+ hostToCertificate *map[string]*tls.Certificate,
+ hostClientAuth *map[string]*x509.CertPool,
+) {
+ if cfg == nil {
+ return
+ }
+ cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
+ if err != nil {
+ log.Fatalln(err)
+ }
+ if cfg.CACert != "" {
+ data, err := ioutil.ReadFile(cfg.CACert)
+ if err != nil {
+ log.Fatalln(err)
+ }
+ block, _ := pem.Decode(data)
+ if block == nil {
+ log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.CACert))
}
- cert, err := tls.LoadX509KeyPair(cfg.TLS.Cert, cfg.TLS.Key)
+ if block.Type != "CERTIFICATE" {
+ log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.CACert))
+ }
+ cert.Certificate = append(cert.Certificate, block.Bytes)
+ }
+ (*hostToCertificate)[host] = &cert
+ pool := x509.NewCertPool()
+ for _, p := range cfg.ClientCAs {
+ data, err := ioutil.ReadFile(p)
if err != nil {
log.Fatalln(err)
}
- if cfg.TLS.CACert != "" {
- data, err := ioutil.ReadFile(cfg.TLS.CACert)
- if err != nil {
- log.Fatalln(err)
- }
- block, _ := pem.Decode(data)
+ var block *pem.Block
+ for len(data) > 0 {
+ block, data = pem.Decode(data)
if block == nil {
- log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.TLS.CACert))
+ log.Fatalln("can not decode PEM:", p)
}
if block.Type != "CERTIFICATE" {
- log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.TLS.CACert))
+ continue
}
- cert.Certificate = append(cert.Certificate, block.Bytes)
+ ca, err := x509.ParseCertificate(block.Bytes)
+ if err != nil {
+ log.Fatalln(err)
+ }
+ pool.AddCert(ca)
}
- HostToCertificate[host] = &cert
+ }
+ if len(pool.Subjects()) > 0 {
+ (*hostClientAuth)[host] = pool
+ }
+}
+
+func LoadCertificates() {
+ HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
+ HostClientAuth = make(map[string]*x509.CertPool)
+ HostToGOSTCertificate = make(map[string]*tls.Certificate, len(Hosts))
+ HostGOSTClientAuth = make(map[string]*x509.CertPool)
+ for host, cfg := range Hosts {
+ loadCertificates(host, cfg.TLS, &HostToCertificate, &HostClientAuth)
+ loadCertificates(host, cfg.GOSTTLS, &HostToGOSTCertificate, &HostGOSTClientAuth)
+ }
+}
+
+func NewTLSConfig() *tls.Config {
+ return &tls.Config{
+ NextProtos: NextProtos,
+ GetCertificate: GetCertificate,
+ GetConfigForClient: GetConfigForClient,
}
}