]> Sergey Matveev's repositories - tofuproxy.git/blobdiff - tls/verify.go
bytes.Equal instead of bytes.Compare
[tofuproxy.git] / tls / verify.go
index 7bff1c185f7c825d976d46fe9852dd4bae50c5ec..2df0c24e5bcf7bc3984244df2676fae6bab71f72 100644 (file)
@@ -24,7 +24,9 @@ import (
        "crypto/x509"
        "encoding/hex"
        "encoding/pem"
+       "errors"
        "fmt"
+       "io/fs"
        "log"
        "os"
        "os/exec"
@@ -206,6 +208,30 @@ func verifyCert(
                        fifos.LogDANE <- fmt.Sprintf("%s\tNAK", host)
                }
        }
+       if len(verifiedChains) > 0 {
+               caHashes := make(map[string]struct{})
+               for _, certs := range verifiedChains {
+                       for _, cert := range certs {
+                               caHashes[spkiHash(cert)] = struct{}{}
+                       }
+               }
+               var restrictedHosts []string
+               caches.RestrictedM.RLock()
+               for h := range caHashes {
+                       restrictedHosts = append(restrictedHosts, caches.Restricted[h]...)
+               }
+               caches.RestrictedM.RUnlock()
+               if len(restrictedHosts) > 0 {
+                       for _, h := range restrictedHosts {
+                               if host == h || strings.HasSuffix(host, "."+h) {
+                                       goto HostIsNotRestricted
+                               }
+                       }
+                       fifos.LogCert <- fmt.Sprintf("Restricted\t%s", host)
+                       return ErrRejected{host}
+               }
+       }
+HostIsNotRestricted:
        fn := filepath.Join(Certs, host)
        certsOur, _, err := ucspi.CertPoolFromFile(fn)
        if err == nil || dialErr != nil || (daneExists && !daneMatched) {
@@ -213,18 +239,18 @@ func verifyCert(
                        caches.AcceptedM.Lock()
                        caches.Accepted[host] = certTheirHash
                        caches.AcceptedM.Unlock()
-                       if bytes.Compare(certsOur[0].Raw, rawCerts[0]) != 0 {
+                       if !bytes.Equal(certsOur[0].Raw, rawCerts[0]) {
                                fifos.LogCert <- fmt.Sprintf("Refresh\t%s\t%s", host, certTheirHash)
                                goto CertUpdate
                        }
                        return nil
                }
                var b bytes.Buffer
-               b.WriteString(fmt.Sprintf("set host \"%s\"\n", host))
+               fmt.Fprintf(&b, "set host \"%s\"\n", host)
                if dialErr == nil {
-                       b.WriteString(fmt.Sprintf("set err \"\"\n"))
+                       fmt.Fprintf(&b, "set err \"\"\n")
                } else {
-                       b.WriteString(fmt.Sprintf("set err \"%s\"\n", dialErr.Error()))
+                       fmt.Fprintf(&b, "set err \"%s\"\n", dialErr.Error())
                }
                var daneStatus string
                if daneExists {
@@ -234,21 +260,17 @@ func verifyCert(
                                daneStatus = "bad"
                        }
                }
-               b.WriteString(fmt.Sprintf("set daneStatus \"%s\"\n", daneStatus))
+               fmt.Fprintf(&b, "set daneStatus \"%s\"\n", daneStatus)
                hexCerts := make([]string, 0, len(rawCerts))
                for _, rawCert := range rawCerts {
                        hexCerts = append(hexCerts, hex.EncodeToString([]byte(certInfo(rawCert))))
                }
-               b.WriteString(fmt.Sprintf(
-                       "set certsTheir \"%s\"\n", strings.Join(hexCerts, " "),
-               ))
+               fmt.Fprintf(&b, "set certsTheir \"%s\"\n", strings.Join(hexCerts, " "))
                hexCerts = make([]string, 0, len(certsOur))
                for _, cert := range certsOur {
                        hexCerts = append(hexCerts, hex.EncodeToString([]byte(certInfo(cert.Raw))))
                }
-               b.WriteString(fmt.Sprintf(
-                       "set certsOur \"%s\"\n", strings.Join(hexCerts, " "),
-               ))
+               fmt.Fprintf(&b, "set certsOur \"%s\"\n", strings.Join(hexCerts, " "))
                b.WriteString(VerifyDialog)
                cmd := exec.Command(CmdWish)
                // os.WriteFile("/tmp/verify-dialog.tcl", b.Bytes(), 0666)
@@ -279,7 +301,7 @@ func verifyCert(
                        return ErrRejected{host}
                }
        } else {
-               if !os.IsNotExist(err) {
+               if !errors.Is(err, fs.ErrNotExist) {
                        return err
                }
                fifos.LogCert <- fmt.Sprintf("TOFU\t%s\t%s", host, certTheirHash)