src/pkg/rpc/client.go | 23 ++++++++++++----------- src/pkg/rpc/server_test.go | 26 ++++++++++++++++++++++++++ diff --git a/src/pkg/rpc/client.go b/src/pkg/rpc/client.go index c77901c6dca8f9c946a4ca7ee043cf0f87667dc1..3dc6df1c4b3096c21424a4fc3441237afc631a57 100644 --- a/src/pkg/rpc/client.go +++ b/src/pkg/rpc/client.go @@ -85,7 +85,8 @@ defer client.sending.Unlock() client.request.Seq = c.seq client.request.ServiceMethod = c.ServiceMethod if err := client.codec.WriteRequest(&client.request, c.Args); err != nil { - panic("rpc: client encode error: " + err.String()) + c.Error = err + c.done() } } @@ -251,10 +252,10 @@ // the invocation. The done channel will signal when the call is complete by returning // the same Call object. If done is nil, Go will allocate a new channel. // If non-nil, done must be buffered or Go will deliberately crash. func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call { - c := new(Call) - c.ServiceMethod = serviceMethod - c.Args = args - c.Reply = reply + call := new(Call) + call.ServiceMethod = serviceMethod + call.Args = args + call.Reply = reply if done == nil { done = make(chan *Call, 10) // buffered. } else { @@ -266,14 +267,14 @@ if cap(done) == 0 { log.Panic("rpc: done channel is unbuffered") } } - c.Done = done + call.Done = done if client.shutdown { - c.Error = ErrShutdown - c.done() - return c + call.Error = ErrShutdown + call.done() + return call } - client.send(c) - return c + client.send(call) + return call } // Call invokes the named function, waits for it to complete, and returns its error status. diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index cb2db2a65d34670a9f94f2b0b9dad7157cd733c3..029741b28b51a85efa6e1ab4d594db5de852d0c4 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -467,6 +467,32 @@ func TestCountMallocsOverHTTP(t *testing.T) { fmt.Printf("mallocs per HTTP rpc round trip: %d\n", countMallocs(dialHTTP, t)) } +type writeCrasher struct{} + +func (writeCrasher) Close() os.Error { + return nil +} + +func (writeCrasher) Read(p []byte) (int, os.Error) { + return 0, os.EOF +} + +func (writeCrasher) Write(p []byte) (int, os.Error) { + return 0, os.NewError("fake write failure") +} + +func TestClientWriteError(t *testing.T) { + c := NewClient(writeCrasher{}) + res := false + err := c.Call("foo", 1, &res) + if err == nil { + t.Fatal("expected error") + } + if err.String() != "fake write failure" { + t.Error("unexpected value of error:", err) + } +} + func benchmarkEndToEnd(dial func() (*Client, os.Error), b *testing.B) { b.StopTimer() once.Do(startServer)