在Go中从服务器获取EOF作为客户端

问题描述:

I have a Go client for a custom protocol. The protocol is lz4-compressed JSON-RPC with a four-byte header giving the length of the compressed JSON.

// ReceiveMessage reads one framed reply from conn and returns it lz4-decoded.
// Per the protocol description above, a frame is a four-byte length header
// followed by that many bytes of lz4-compressed JSON. readTimeout, unpack and
// lz4 are defined outside this snippet.
//
// NOTE(review): as written this can only ever return an error or time out;
// see the inline notes below.
func ReceiveMessage(conn net.Conn) ([]byte, error) {
    start := time.Now()
    bodyLen := 0
    body := make([]byte, 0, 4096)
    // NOTE(review): length 0, capacity 256. conn.Read(buf[:]) below passes an
    // empty slice, and Read fills at most len(p) bytes, so no data is ever
    // received. This needs make([]byte, 256) (or var buf [256]byte).
    buf := make([]byte, 0, 256)

    for bodyLen == 0 || len(body) < bodyLen {
        // NOTE(review): body[:4] keeps only the header and discards any payload
        // already read; dropping the header needs body[4:]. This check also
        // runs on every pass instead of only while bodyLen == 0, so body is
        // cut back to four bytes again each time it grows past four.
        if len(body) > 4 {
            header := body[:4]
            body = body[:4]
            bodyLen = int(unpack(header))
        }

        n, err := conn.Read(buf[:])
        if err != nil {
            // NOTE(review): io.EOF is swallowed here, so after the peer closes
            // the connection the loop keeps polling until the timeout check
            // below gives up.
            if err != io.EOF {
                return body, err
            }
        }
        body = append(body, buf[0:n]...)

        // readTimeout is interpreted as a number of milliseconds. This check
        // only runs between reads, so it cannot interrupt a Read that blocks.
        now := time.Now()
        if now.Sub(start) > time.Duration(readTimeout) * time.Millisecond     {
            return body, fmt.Errorf("Timed-out while reading from socket.")
        }
        time.Sleep(time.Duration(1) * time.Millisecond)
    }

    return lz4.Decode(nil, body)
}

The client:

func main() {
    address := os.Args[1]
    msg := []byte(os.Args[2])

    fmt.Printf("Sending %s to %s
", msg, address)

    conn, err := net.Dial(address)
    if err != nil {
        fmt.Printf("%v
", err)
        return
    }

    // Another library call
    _, err = SendMessage(conn, []byte(msg))
    if err != nil {
        fmt.Printf("%v
", err)
        return
    }

    response, err := ReceiveMessage(conn)
    conn.Close()

    if err != nil {
        fmt.Printf("%v
", err)
        return
    }

    fmt.Printf("Response: %s
", response)
}

When I call it, I get no response and it just times out. (If I do not explicitly ignore the EOF, it returns there with an io.EOF error.) I have another library for this, written in Python, that works against the same endpoint with the same payload. Do you see anything obviously wrong?

[JimB just beat me to an answer but here goes anyway.]

The root issue is that you did body = body[:4] when you wanted body = body[4:]. The former keeps only the first four header bytes while the latter tosses the four header bytes just decoded.

Here is a self contained version with some debug logs that works. It has some of the other changes I mentioned. (I guessed at various things that you didn't include, like the lz4 package used, the timeout, unpack, etc.)

package main

import (
        "encoding/binary"
        "errors"
        "fmt"
        "io"
        "log"
        "net"
        "time"

        "github.com/bkaradzic/go-lz4"
)

const readTimeout = 30 * time.Second // XXX guess

// ReceiveMessage reads one length-prefixed message from conn and returns its
// lz4-decoded payload. A message is a 4-byte header, decoded with unpack into
// the payload length, followed by exactly that many bytes of compressed data.
//
// A read deadline of readTimeout covers the whole message and is cleared
// before returning. Reading stops at the end of the frame, so bytes belonging
// to a following message stay on the connection for the next call.
func ReceiveMessage(conn net.Conn) ([]byte, error) {
	// Only the read deadline is touched so a caller's write deadline survives.
	if err := conn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
		return nil, fmt.Errorf("setting read deadline: %w", err)
	}
	defer conn.SetReadDeadline(time.Time{}) // best effort: disable deadline

	// io.ReadFull keeps reading until all four header bytes arrive, however
	// they are split across reads. It also keeps data that a Read returns
	// together with io.EOF. A stream that ends early yields io.EOF (nothing
	// read) or io.ErrUnexpectedEOF (partial header).
	var header [4]byte
	if _, err := io.ReadFull(conn, header[:]); err != nil {
		return nil, fmt.Errorf("reading body length: %w", err)
	}
	bodyLen := int(unpack(header[:]))
	if bodyLen <= 0 {
		return nil, errors.New("invalid body length")
	}
	log.Println("read bodyLen:", bodyLen)

	// LimitReader stops exactly at the frame boundary. ReadAll grows the
	// buffer as data arrives rather than trusting the peer-supplied length for
	// one large up-front allocation. It returns a nil error at EOF, so a short
	// body is detected by the length check below.
	body, err := io.ReadAll(io.LimitReader(conn, int64(bodyLen)))
	if err != nil {
		return nil, fmt.Errorf("reading body: %w", err)
	}
	if len(body) != bodyLen {
		return nil, fmt.Errorf("got %d bytes, expected %d",
			len(body), bodyLen)
	}

	return lz4.Decode(nil, body)
}

// address is where the in-process test server listens and where main dials.
const address = ":5678"

// msg is the sample JSON request body that main sends.
var msg = []byte("{\"foo\":\"bar\"}")

func main() {
        //address := os.Args[1]
        //msg := []byte(os.Args[2])

        fmt.Printf("Sending %s to %s
", msg, address)

        conn, err := net.Dial("tcp", address)
        if err != nil {
                fmt.Printf("%v
", err)
                return
        }

        // Another library call
        _, err = SendMessage(conn, msg)
        if err != nil {
                fmt.Printf("%v
", err)
                return
        }

        response, err := ReceiveMessage(conn)
        conn.Close()

        if err != nil {
                fmt.Printf("%v
", err)
                return
        }

        fmt.Printf("Response: %s
", response)
}

// unpack decodes the first four bytes of b as a little-endian uint32 (the
// frame's length header); bytes after the fourth are ignored and a shorter
// slice panics. The byte order is a guess at what the asker's unpack does.
func unpack(b []byte) uint32 {
	var order binary.ByteOrder = binary.LittleEndian
	return order.Uint32(b)
}

// SendMessage is a stub standing in for the asker's real send routine: it
// writes nothing to the connection and always reports 0 bytes sent and a nil
// error.
func SendMessage(_ net.Conn, _ []byte) (sent int, err error) {
	return sent, err
}

// init starts a simple TCP test server on address inside this process so the
// example is self-contained. Each accepted connection is handled by Serve on
// its own goroutine. A listen failure or any Accept error terminates the
// program via log.Fatal/log.Fatalln.
func init() {
	listener, err := net.Listen("tcp", address)
	if err != nil {
		log.Fatal(err)
	}

	acceptLoop := func() {
		// The loop only exits through log.Fatalln, which calls os.Exit, so in
		// practice this deferred Close never runs.
		defer listener.Close()
		for {
			client, acceptErr := listener.Accept()
			if acceptErr != nil {
				log.Fatalln("accept:", acceptErr)
			}
			go Serve(client)
		}
	}
	go acceptLoop()
}

// Serve answers one client on c with a fixed JSON reply, framed as a 4-byte
// little-endian length followed by the lz4-compressed payload. The incoming
// request is never read. c is always closed before Serve returns, and
// failures are logged rather than returned.
func Serve(c net.Conn) {
	defer c.Close()

	const response = `{"somefield": "someval"}`
	// Compress in one shot rather than streaming through an io.Writer, because
	// the header must carry the compressed length before the payload is sent.
	payload, err := lz4.Encode(nil, []byte(response))
	if err != nil {
		log.Println("lz4 encode:", err)
		return
	}

	log.Println("sending len:", len(payload))
	var header [4]byte
	binary.LittleEndian.PutUint32(header[:], uint32(len(payload)))
	if _, err := c.Write(header[:]); err != nil {
		log.Println("writing len:", err)
		return
	}

	log.Println("sending data")
	if _, err := c.Write(payload); err != nil {
		log.Println("writing compressed response:", err)
		return
	}
	log.Println("Serve done, closing connection")
}

Playground (but not runnable there).

You have a number of issues with the client code. Without a full reproducing case, it's hard to tell whether fixing them will solve everything.

    for bodyLen == 0 || len(body) < bodyLen {
        if len(body) > 4 {
            header := body[:4]
            body = body[:4]
            bodyLen = int(unpack(header))
        }

On every iteration where len(body) > 4, you slice body back to its first 4 bytes, so body can never grow to bodyLen.

        n, err := conn.Read(buf[:])

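You don't need to re-slice buf here; use conn.Read(buf). More importantly, buf was created with make([]byte, 0, 256), which has length 0. Read fills at most len(buf) bytes, so every call returns without delivering any data. Allocate it as make([]byte, 256), or declare var buf [256]byte.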

        if err != nil {
            if err != io.EOF {
                return body, err
            }
        }

io.EOF marks the end of the stream, and you need to handle it. Note that n may still be > 0 when you get io.EOF, so append those n bytes first and then check for io.EOF; otherwise you could loop indefinitely.

        body = append(body, buf[0:n]...)

        now := time.Now()
        if now.Sub(start) > time.Duration(readTimeout) * time.Millisecond     {
            return body, fmt.Errorf("Timed-out while reading from socket.")

You would be better off using conn.SetReadDeadline before each read, so that a stalled Read can be interrupted.