Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use bufio for mse write #960

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 88 additions & 104 deletions mse/mse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package mse

import (
"bufio"
"bytes"
"crypto/rand"
"crypto/rc4"
Expand All @@ -16,8 +17,6 @@ import (
"math/big"
"strconv"
"sync"

"github.com/anacrolix/missinggo/perf"
)

const (
Expand Down Expand Up @@ -160,12 +159,19 @@ func paddedLeft(b []byte, _len int) []byte {
func (h *handshake) postY(x *big.Int) error {
var y big.Int
y.Exp(&g, x, &p)
return h.postWrite(paddedLeft(y.Bytes(), 96))
return h.write(paddedLeft(y.Bytes(), 96))
}

func (h *handshake) establishS() error {
func (h *handshake) establish() error {
x := newX()
h.postY(&x)
if err := h.postY(&x); err != nil {
return err
}

if err := h.w.Flush(); err != nil {
return err
}

var b [96]byte
_, err := io.ReadFull(h.conn, b[:])
if err != nil {
Expand Down Expand Up @@ -194,6 +200,7 @@ func newPadLen() int64 {
// Manages state for both initiating and receiving handshakes.
type handshake struct {
conn io.ReadWriter
w *bufio.Writer
s [96]byte
initer bool // Whether we're initiating or receiving.
skeys SecretKeyIter // Skeys we'll accept if receiving.
Expand All @@ -203,80 +210,15 @@ type handshake struct {
chooseMethod CryptoSelector
// Sent to the receiver.
cryptoProvides CryptoMethod

writeMu sync.Mutex
writes [][]byte
writeErr error
writeCond sync.Cond
writeClose bool

writerMu sync.Mutex
writerCond sync.Cond
writerDone bool
}

func (h *handshake) finishWriting() {
h.writeMu.Lock()
h.writeClose = true
h.writeCond.Broadcast()
h.writeMu.Unlock()

h.writerMu.Lock()
for !h.writerDone {
h.writerCond.Wait()
}
h.writerMu.Unlock()
}

func (h *handshake) writer() {
defer func() {
h.writerMu.Lock()
h.writerDone = true
h.writerCond.Broadcast()
h.writerMu.Unlock()
}()
for {
h.writeMu.Lock()
for {
if len(h.writes) != 0 {
break
}
if h.writeClose {
h.writeMu.Unlock()
return
}
h.writeCond.Wait()
}
b := h.writes[0]
h.writes = h.writes[1:]
h.writeMu.Unlock()
_, err := h.conn.Write(b)
if err != nil {
h.writeMu.Lock()
h.writeErr = err
h.writeMu.Unlock()
return
}
}
}

func (h *handshake) postWrite(b []byte) error {
h.writeMu.Lock()
defer h.writeMu.Unlock()
if h.writeErr != nil {
return h.writeErr
}
h.writes = append(h.writes, b)
h.writeCond.Signal()
return nil
func (h *handshake) write(b []byte) error {
_, err := h.w.Write(b)
return err
}

func xor(a, b []byte) (ret []byte) {
max := len(a)
if max > len(b) {
max = len(b)
}
ret = make([]byte, max)
ret = make([]byte, max(len(a), len(b)))
xorInPlace(ret, a, b)
return
}
Expand Down Expand Up @@ -358,8 +300,16 @@ func (h *handshake) newEncrypt(initer bool) *rc4.Cipher {
}

func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err error) {
h.postWrite(hash(req1, h.s[:]))
h.postWrite(xor(hash(req2, h.skey), hash(req3, h.s[:])))
err = h.write(hash(req1, h.s[:]))
if err != nil {
return
}

err = h.write(xor(hash(req2, h.skey), hash(req3, h.s[:])))
if err != nil {
return
}

buf := &bytes.Buffer{}
padLen := uint16(newPadLen())
if len(h.ia) > math.MaxUint16 {
Expand All @@ -373,7 +323,16 @@ func (h *handshake) initerSteps() (ret io.ReadWriter, selected CryptoMethod, err
e := h.newEncrypt(true)
be := make([]byte, buf.Len())
e.XORKeyStream(be, buf.Bytes())
h.postWrite(be)
err = h.write(be)
if err != nil {
return
}

err = h.w.Flush()
if err != nil {
return
}

bC := h.newEncrypt(false)
var eVC [8]byte
bC.XORKeyStream(eVC[:], vc[:])
Expand Down Expand Up @@ -465,22 +424,34 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
return
}
var lenIA uint16
unmarshal(r, &lenIA)
if err = unmarshal(r, &lenIA); err != nil {
return
}

if lenIA != 0 {
h.ia = make([]byte, lenIA)
unmarshal(r, h.ia)
err = unmarshal(r, h.ia)
if err != nil {
return
}
}

buf := &bytes.Buffer{}
w := cipherWriter{h.newEncrypt(false), buf, nil}
padLen = uint16(newPadLen())
err = marshal(&w, &vc, uint32(chosen), padLen, zeroPad[:padLen])
if err != nil {
return
}
err = h.postWrite(buf.Bytes())
err = h.write(buf.Bytes())
if err != nil {
return
}
err = h.w.Flush()
if err != nil {
return
}

switch chosen {
case CryptoMethodRC4:
ret = readWriter{
Expand All @@ -498,66 +469,79 @@ func (h *handshake) receiverSteps() (ret io.ReadWriter, chosen CryptoMethod, err
return
}

func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
h.writeCond.L = &h.writeMu
h.writerCond.L = &h.writerMu
go h.writer()
defer func() {
h.finishWriting()
if err == nil {
err = h.writeErr
type buffer struct {
B []byte
}

var pool = sync.Pool{
New: func() any {
return &buffer{
B: make([]byte, maxPadLen),
}
}()
err = h.establishS()
},
}

func (h *handshake) Do() (ret io.ReadWriter, method CryptoMethod, err error) {
err = h.establish()
if err != nil {
err = fmt.Errorf("error while establishing secret: %w", err)
return
}
pad := make([]byte, newPadLen())
io.ReadFull(rand.Reader, pad)
err = h.postWrite(pad)

pad := pool.Get().(*buffer)
defer pool.Put(pad)

size := newPadLen()
_, err = io.ReadFull(rand.Reader, pad.B[:size])
if err != nil {
panic(fmt.Sprintln("unexpected error when reading from random", err))
}

err = h.write(pad.B[:size])
if err != nil {
return
}

if h.initer {
ret, method, err = h.initerSteps()
} else {
ret, method, err = h.receiverSteps()
}

return
}

func InitiateHandshake(
rw io.ReadWriter, skey, initialPayload []byte, cryptoProvides CryptoMethod,
) (
func InitiateHandshake(rw io.ReadWriter, key, initialPayload []byte, cryptoProvides CryptoMethod) (
ret io.ReadWriter, method CryptoMethod, err error,
) {
h := handshake{
conn: rw,
w: bufio.NewWriter(rw),
initer: true,
skey: skey,
skey: key,
ia: initialPayload,
cryptoProvides: cryptoProvides,
}
defer perf.ScopeTimerErr(&err)()

return h.Do()
}

func ReceiveHandshake(rw io.ReadWriter, keys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
res := ReceiveHandshakeEx(rw, keys, selectCrypto)
return res.ReadWriter, res.CryptoMethod, res.error
}

type HandshakeResult struct {
io.ReadWriter
CryptoMethod
error
SecretKey []byte
}

func ReceiveHandshake(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (io.ReadWriter, CryptoMethod, error) {
res := ReceiveHandshakeEx(rw, skeys, selectCrypto)
return res.ReadWriter, res.CryptoMethod, res.error
}

func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto CryptoSelector) (ret HandshakeResult) {
h := handshake{
conn: rw,
w: bufio.NewWriter(rw),
initer: false,
skeys: skeys,
chooseMethod: selectCrypto,
Expand All @@ -567,7 +551,7 @@ func ReceiveHandshakeEx(rw io.ReadWriter, skeys SecretKeyIter, selectCrypto Cryp
return
}

// A function that given a function, calls it with secret keys until it
// SecretKeyIter is a function that given a function, calls it with secret keys until it
// returns false or exhausted.
type SecretKeyIter func(callback func(skey []byte) (more bool))

Expand Down