]> Sergey Matveev's repositories - godlighty.git/blob - tls.go
Missing copying
[godlighty.git] / tls.go
1 /*
2 godlighty -- highly-customizable HTTP, HTTP/2, HTTPS server
3 Copyright (C) 2021-2022 Sergey Matveev <stargrave@stargrave.org>
4
5 This program is free software: you can redistribute it and/or modify
6 it under the terms of the GNU General Public License as published by
7 the Free Software Foundation, version 3 of the License.
8
9 This program is distributed in the hope that it will be useful,
10 but WITHOUT ANY WARRANTY; without even the implied warranty of
11 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 GNU General Public License for more details.
13
14 You should have received a copy of the GNU General Public License
15 along with this program.  If not, see <http://www.gnu.org/licenses/>.
16 */
17
18 package godlighty
19
20 import (
21         "crypto/tls"
22         "crypto/x509"
23         "encoding/pem"
24         "errors"
25         "fmt"
26         "log"
27         "os"
28 )
29
// NextProtos lists the ALPN protocol identifiers placed into every
// tls.Config built in this file.
var NextProtos = []string{"h2", "http/1.1"}

// Tables for ordinary TLS, keyed by host name: the certificate to
// present and, where configured, the pool used to verify client
// certificates.
var (
	HostToCertificate map[string]*tls.Certificate
	HostClientAuth    map[string]*x509.CertPool
)

// The same two tables, consulted first when CHIHasGOST reports true
// for a ClientHello.
var (
	HostToGOSTCertificate map[string]*tls.Certificate
	HostGOSTClientAuth    map[string]*x509.CertPool
)
39
40 func GetCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
41         if CHIHasGOST(chi) {
42                 if cert := HostToGOSTCertificate[chi.ServerName]; cert != nil {
43                         return cert, nil
44                 }
45         }
46         cert := HostToCertificate[chi.ServerName]
47         if cert == nil {
48                 return nil, errors.New("no certificate found")
49         }
50         return cert, nil
51 }
52
53 func GetConfigForClient(chi *tls.ClientHelloInfo) (*tls.Config, error) {
54         var pool *x509.CertPool
55         if CHIHasGOST(chi) {
56                 pool = HostGOSTClientAuth[chi.ServerName]
57         }
58         if pool == nil {
59                 pool = HostClientAuth[chi.ServerName]
60         }
61         if pool == nil {
62                 return nil, nil
63         }
64         return &tls.Config{
65                 GetCertificate: GetCertificate,
66                 NextProtos:     NextProtos,
67                 ClientCAs:      pool,
68                 ClientAuth:     tls.RequireAndVerifyClientCert,
69         }, nil
70 }
71
72 func loadCertificates(
73         host string,
74         cfg *TLSCfg,
75         hostToCertificate *map[string]*tls.Certificate,
76         hostClientAuth *map[string]*x509.CertPool,
77 ) {
78         if cfg == nil {
79                 return
80         }
81         cert, err := tls.LoadX509KeyPair(cfg.Cert, cfg.Key)
82         if err != nil {
83                 log.Fatalln(err)
84         }
85         if cfg.CACert != "" {
86                 data, err := os.ReadFile(cfg.CACert)
87                 if err != nil {
88                         log.Fatalln(err)
89                 }
90                 block, _ := pem.Decode(data)
91                 if block == nil {
92                         log.Fatalln(fmt.Errorf("no PEM found: %s", cfg.CACert))
93                 }
94                 if block.Type != "CERTIFICATE" {
95                         log.Fatalln(fmt.Errorf("non CERTIFICATE: %s", cfg.CACert))
96                 }
97                 cert.Certificate = append(cert.Certificate, block.Bytes)
98         }
99         (*hostToCertificate)[host] = &cert
100         pool := x509.NewCertPool()
101         for _, p := range cfg.ClientCAs {
102                 data, err := os.ReadFile(p)
103                 if err != nil {
104                         log.Fatalln(err)
105                 }
106                 var block *pem.Block
107                 for len(data) > 0 {
108                         block, data = pem.Decode(data)
109                         if block == nil {
110                                 log.Fatalln("can not decode PEM:", p)
111                         }
112                         if block.Type != "CERTIFICATE" {
113                                 continue
114                         }
115                         ca, err := x509.ParseCertificate(block.Bytes)
116                         if err != nil {
117                                 log.Fatalln(err)
118                         }
119                         pool.AddCert(ca)
120                 }
121         }
122         if len(pool.Subjects()) > 0 {
123                 (*hostClientAuth)[host] = pool
124         }
125 }
126
127 func LoadCertificates() {
128         HostToCertificate = make(map[string]*tls.Certificate, len(Hosts))
129         HostClientAuth = make(map[string]*x509.CertPool)
130         HostToGOSTCertificate = make(map[string]*tls.Certificate, len(Hosts))
131         HostGOSTClientAuth = make(map[string]*x509.CertPool)
132         for host, cfg := range Hosts {
133                 loadCertificates(host, cfg.TLS, &HostToCertificate, &HostClientAuth)
134                 loadCertificates(host, cfg.GOSTTLS, &HostToGOSTCertificate, &HostGOSTClientAuth)
135         }
136 }
137
138 func NewTLSConfig() *tls.Config {
139         return &tls.Config{
140                 NextProtos:         NextProtos,
141                 GetCertificate:     GetCertificate,
142                 GetConfigForClient: GetConfigForClient,
143         }
144 }