]> Sergey Matveev's repositories - btrtrc.git/blobdiff - mse/mse.go
Set the connection.cryptoMethod
[btrtrc.git] / mse / mse.go
index d843d8f93f12b89f87e492d38caef03a4d9d0e64..83454b82e510cb22a03245330393eb94f2c9ed09 100644 (file)
@@ -367,7 +367,7 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
        return newEncrypt(initer, h.s[:], h.skey)
 }
 
-func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
        h.postWrite(hash(req1, h.s[:]))
        h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
        buf := &bytes.Buffer{}
@@ -409,7 +409,8 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
        if err != nil {
                return
        }
-       switch method & h.cryptoProvides {
+       selected = method & h.cryptoProvides
+       switch selected {
        case CryptoMethodRC4:
                ret = readWriter{r, &cipherWriter{e, h.conn, nil}}
        case CryptoMethodPlaintext:
@@ -422,7 +423,7 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, err error) {
 
 var ErrNoSecretKeyMatch = errors.New("no skey matched")
 
-func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
+func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err error) {
        // There is up to 512 bytes of padding, then the 20 byte hash.
        err = readUntil(io.LimitReader(h.conn, 532), hash(req1, h.s[:]))
        if err != nil {
@@ -460,7 +461,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
                return
        }
        cryptoProvidesCount.Add(strconv.FormatUint(uint64(provides), 16), 1)
-       chosen := h.chooseMethod(provides)
+       chosen = h.chooseMethod(provides)
        _, err = io.CopyN(ioutil.Discard, r, int64(padLen))
        if err != nil {
                return
@@ -499,7 +500,7 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, err error) {
        return
 }
 
-func (h *handshake) Do() (ret io.ReadWriter, err error) {
+func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
        h.writeCond.L = &h.writeMu
        h.writerCond.L = &h.writerMu
        go h.writer()
@@ -521,14 +522,14 @@ func (h *handshake) Do() (ret io.ReadWriter, err error) {
                return
        }
        if h.initer {
-               ret, err = h.initerSteps()
+               ret, method, err = h.initerSteps()
        } else {
-               ret, err = h.receiverSteps()
+               ret, method, err = h.receiverSteps()
        }
        return
 }
 
-func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, err error) {
+func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cryptoProvides CryptoMethod) (ret io.ReadWriter, method CryptoMethod, err error) {
        h := handshake{
                conn:           rw,
                initer:         true,
@@ -539,7 +540,7 @@ func InitiateHandshake(rw io.ReadWriter, skey []byte, initialPayload []byte, cry
        return h.Do()
 }
 
-func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, err error) {
+func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret io.ReadWriter, method CryptoMethod, err error) {
        h := handshake{
                conn:         rw,
                initer:       false,