diff --git a/pkg/cover/instrument.go b/pkg/cover/instrument.go index c313ba5..4cbcb70 100644 --- a/pkg/cover/instrument.go +++ b/pkg/cover/instrument.go @@ -45,6 +45,8 @@ package main import ( "bufio" + "bytes" + "encoding/json" "fmt" "io" "io/ioutil" @@ -52,10 +54,12 @@ import ( "net" "net/http" "os" + "os/signal" + "path/filepath" "strings" "sync/atomic" + "syscall" "testing" - "path/filepath" _cover {{.GlobalCoverVarImportPath | printf "%q"}} @@ -137,6 +141,23 @@ func registerHandlers() { log.Fatalf("register address %v failed, err: %v, response: %v", profileAddr, err, string(resp)) } + fn := func() { + var ( + err error + profileAddrs []string + addresses []string + ) + if addresses, err = getAllHosts(ln); err != nil { + log.Fatalf("get all host failed, err: %v", err) + return + } + for _, addr := range addresses { + profileAddrs = append(profileAddrs, "http://"+addr) + } + deregisterSelf(profileAddrs) + } + go watchSignal(fn) + mux := http.NewServeMux() // Coverage reports the current code coverage as a fraction in the range [0, 1]. // If coverage is not enabled, Coverage returns 0. @@ -225,6 +246,64 @@ func registerSelf(address string) ([]byte, error) { return body, err } +func deregisterSelf(address []string) ([]byte, error) { + param := map[string]interface{}{ + "address": address, + } + jsonBody, err := json.Marshal(param) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", fmt.Sprintf("%s/v1/cover/remove", {{.Center | printf "%q"}}), bytes.NewReader(jsonBody)) + if err != nil { + log.Fatalf("http.NewRequest failed: %v", err) + return nil, err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil && isNetworkError(err) { + log.Printf("[goc][WARN]error occurred:%v, try again", err) + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return nil, fmt.Errorf("failed to register into coverage center, err:%v", err) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body, err:%v", err) + } + + if resp.StatusCode != 200 { + err = fmt.Errorf("failed to register into coverage center, response code %d", resp.StatusCode) + } + + return body, err +} + +type CallbackFunc func() + +func watchSignal(fn CallbackFunc) { + defer fn() + + // init signal + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP, syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT) + for { + si := <-c + log.Printf("get a signal %s", si.String()) + switch si { + case syscall.SIGQUIT, syscall.SIGTERM, syscall.SIGINT: + return + case syscall.SIGHUP: + default: + return + } + } +} + func isNetworkError(err error) bool { if err == io.EOF { return true @@ -290,6 +369,22 @@ func getRealHost(ln net.Listener) (host string, err error) { return } +func getAllHosts(ln net.Listener) (hosts []string, err error) { + adds, err := net.InterfaceAddrs() + if err != nil { + return + } + + var host string + for _, addr := range adds { + if ipNet, ok := addr.(*net.IPNet); ok && ipNet.IP.To4() != nil { + host = fmt.Sprintf("%s:%d", ipNet.IP.String(), ln.Addr().(*net.TCPAddr).Port) + hosts = append(hosts, host) + } + } + return +} + func getPreviousAddr() string { file, err := os.Open(os.Args[0] + "_profile_listen_addr") if err != nil {