diff --git a/cmd/build.go b/cmd/build.go index 197ed1a..1aab87c 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -14,7 +14,7 @@ var buildCmd = &cobra.Command{ } func init() { - buildCmd.Flags().StringVarP(&config.GocConfig.Mode, "mode", "", "count", "coverage mode: set, count, atomic") + buildCmd.Flags().StringVarP(&config.GocConfig.Mode, "mode", "", "count", "coverage mode: set, count, atomic, watch") buildCmd.Flags().StringVarP(&config.GocConfig.Host, "host", "", "127.0.0.1:7777", "specify the host of the goc sever") rootCmd.AddCommand(buildCmd) } diff --git a/pkg/build/tmpfolder.go b/pkg/build/tmpfolder.go index ef83c25..d81c415 100644 --- a/pkg/build/tmpfolder.go +++ b/pkg/build/tmpfolder.go @@ -48,7 +48,7 @@ func tmpFolderName(path string) string { sum := sha256.Sum256([]byte(path)) h := fmt.Sprintf("%x", sum[:6]) - return "goc-build-" + h + return "gocbuild" + h } // skipCopy skip copy .git dir and irregular files diff --git a/pkg/config/config.go b/pkg/config/config.go index e23f541..581698e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -72,7 +72,7 @@ type PackageCover struct { // FileVar holds the name of the generated coverage variables targeting the named file. type FileVar struct { - File string + File string // 这里其实不是文件名,是 importpath + filename Var string } diff --git a/pkg/cover/agent.tpl b/pkg/cover/agent.tpl index 4995f4f..64d32aa 100644 --- a/pkg/cover/agent.tpl +++ b/pkg/cover/agent.tpl @@ -53,11 +53,11 @@ func init() { rpcstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/rpcstream?%v", host, v.Encode()) ws, _, err := dialer.Dial(rpcstreamUrl, nil) if err != nil { - log.Printf("[goc][Error] fail to dial to goc server: %v", err) + log.Printf("[goc][Error] rpc fail to dial to goc server: %v", err) time.Sleep(waitDelay) continue } - log.Printf("[goc][Info] connected to goc server") + log.Printf("[goc][Info] rpc connected to goc server") rwc := &ReadWriteCloser{ws: ws} s := rpc.NewServer() @@ -67,7 +67,7 @@ func init() { // exit rpc server, close ws connection ws.Close() time.Sleep(waitDelay) - log.Printf("[goc][Error] connection to goc server broken", ) + log.Printf("[goc][Error] rpc connection to goc server broken", ) } }() } diff --git a/pkg/cover/agentwatch.tpl b/pkg/cover/agentwatch.tpl new file mode 100644 index 0000000..c0ddc89 --- /dev/null +++ b/pkg/cover/agentwatch.tpl @@ -0,0 +1,151 @@ +package coverdef + +import ( + "fmt" + "time" + "os" + "log" + "strconv" + "net/url" + + "{{.GlobalCoverVarImportPath}}/websocket" +) + +var ( + watchChannel = make(chan *blockInfo, 1024) + + watchEnabled = false + + waitDelay time.Duration = 10 * time.Second + host string = "{{.Host}}" +) + +func init() { + // init host + host_env := os.Getenv("GOC_CUSTOM_HOST") + if host_env != "" { + host = host_env + } + + var dialer = websocket.DefaultDialer + + go func() { + for { + // 获取进程元信息用于注册 + ps, err := getRegisterInfo() + if err != nil { + time.Sleep(waitDelay) + continue + } + + // 注册,直接将元信息放在 ws 地址中 + v := url.Values{} + v.Set("hostname", ps.hostname) + v.Set("pid", strconv.Itoa(ps.pid)) + v.Set("cmdline", ps.cmdline) + v.Encode() + + watchstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/watchstream?%v", host, v.Encode()) + ws, _, err := dialer.Dial(watchstreamUrl, nil) + if err != nil { + log.Printf("[goc][Error] watch fail to dial to goc server: %v", err) + time.Sleep(waitDelay) + continue + } + + // 连接成功 + watchEnabled = true + log.Printf("[goc][Info] watch connected to goc server") + + ticker := time.NewTicker(time.Second) + closeFlag := false + go func() { + for { + // 必须调用一下以触发 ping 的自动处理 + _, _, err := ws.ReadMessage() + if err != nil { + break + } + } + closeFlag = true + }() + + Loop: + for { + select { + case block := <-watchChannel: + i := block.i + + cov := fmt.Sprintf("%s:%d.%d,%d.%d %d %d", block.name, + block.pos[3*i+0], uint16(block.pos[3*i+2]), + block.pos[3*i+1], uint16(block.pos[3*i+2] >> 16), + 1, + 0) + + err = ws.WriteMessage(websocket.TextMessage, []byte(cov)) + if err != nil { + watchEnabled = false + log.Println("[goc][Error] push coverage failed: %v", err) + time.Sleep(waitDelay) + break Loop + } + case <-ticker.C: + if closeFlag == true { + break Loop + } + } + } + } + }() +} + +// get process meta info for register +type processInfo struct { + hostname string + pid int + cmdline string +} + +func getRegisterInfo() (*processInfo, error) { + hostname, err := os.Hostname() + if err != nil { + log.Printf("[goc][Error] fail to get hostname: %v", hostname) + return nil, err + } + + pid := os.Getpid() + + cmdline := os.Args[0] + + return &processInfo{ + hostname: hostname, + pid: pid, + cmdline: cmdline, + }, nil +} + +// + +type blockInfo struct { + name string + pos []uint32 + i int +} + +// UploadCoverChangeEvent_{{.Random}} is non-blocking +func UploadCoverChangeEvent_{{.Random}}(name string, pos []uint32, i int) { + + if watchEnabled == false { + return + } + + // make sure send is non-blocking + select { + case watchChannel <- &blockInfo{ + name: name, + pos: pos, + i: i, + }: + default: + } +} diff --git a/pkg/cover/inject.go b/pkg/cover/inject.go index 0d7002b..b032f78 100644 --- a/pkg/cover/inject.go +++ b/pkg/cover/inject.go @@ -56,6 +56,12 @@ func Inject() { } // 在工程根目录注入所有插桩变量的声明+定义 injectGlobalCoverVarFile(allDecl) + // 在工程根目录注入 watch agent 的定义 + if config.GocConfig.Mode == "watch" { + log.Infof("watch mode is enabled") + injectWatchAgentFile() + log.Donef("watch handler injected") + } // 添加自定义 websocket 依赖 // 用户代码可能有 gorrila/websocket 的依赖,为避免依赖冲突,以及可能的 replace/vendor, // 这里直接注入一份完整的 gorrila/websocket 实现 @@ -81,7 +87,7 @@ func addCounters(pkg *config.Package) (*config.PackageCover, string) { decl := "" for file, coverVar := range coverVarMap { - decl += "\n" + tool.Annotate(filepath.Join(getPkgTmpDir(pkg.Dir), file), mode, coverVar.Var, gobalCoverVarImportPath) + "\n" + decl += "\n" + tool.Annotate(filepath.Join(getPkgTmpDir(pkg.Dir), file), mode, coverVar.Var, coverVar.File, gobalCoverVarImportPath) + "\n" } return &config.PackageCover{ @@ -120,6 +126,7 @@ func getPkgTmpDir(pkgDir string) string { // 使用 bridge.go 文件是为了避免插桩逻辑中的变量名污染 main 包 func injectGocAgent(where string, covers []*config.PackageCover) { injectPkgName := "goc-cover-agent-apis-auto-generated-11111-22222-package" + injectBridgeName := "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go" wherePkg := filepath.Join(where, injectPkgName) err := os.MkdirAll(wherePkg, os.ModePerm) if err != nil { @@ -127,7 +134,7 @@ func injectGocAgent(where string, covers []*config.PackageCover) { } // create bridge file - whereBridge := filepath.Join(where, "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go") + whereBridge := filepath.Join(where, injectBridgeName) f, err := os.Create(whereBridge) if err != nil { log.Fatalf("fail to create cover bridge file in temporary project: %v", err) @@ -153,6 +160,12 @@ func injectGocAgent(where string, covers []*config.PackageCover) { } defer f.Close() + var _coverMode string + if config.GocConfig.Mode == "watch" { + _coverMode = "cover" + } else { + _coverMode = config.GocConfig.Mode + } tmplData := struct { Covers []*config.PackageCover GlobalCoverVarImportPath string @@ -164,7 +177,7 @@ func injectGocAgent(where string, covers []*config.PackageCover) { GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath, Package: injectPkgName, Host: config.GocConfig.Host, - Mode: config.GocConfig.Mode, + Mode: _coverMode, } if err := coverMainTmpl.Execute(f, tmplData); err != nil { @@ -196,3 +209,27 @@ func injectGlobalCoverVarFile(decl string) { log.Fatalf("fail to write to global cover definition file: %v", err) } } + +func injectWatchAgentFile() { + globalCoverVarPackage := path.Base(config.GocConfig.GlobalCoverVarImportPath) + globalCoverDef := filepath.Join(config.GocConfig.TmpModProjectDir, globalCoverVarPackage) + + f, err := os.Create(filepath.Join(globalCoverDef, "watchagent.go")) + if err != nil { + log.Fatalf("fail to create watchagent file: %v", err) + } + + tmplData := struct { + Random string + Host string + GlobalCoverVarImportPath string + }{ + Random: filepath.Base(config.GocConfig.TmpModProjectDir), + Host: config.GocConfig.Host, + GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath, + } + + if err := coverWatchTmpl.Execute(f, tmplData); err != nil { + log.Fatalf("fail to generate watchagent in temporary project: %v", err) + } +} diff --git a/pkg/cover/internal/tool/cover.go b/pkg/cover/internal/tool/cover.go index 31f5337..fb2a93a 100644 --- a/pkg/cover/internal/tool/cover.go +++ b/pkg/cover/internal/tool/cover.go @@ -6,6 +6,8 @@ package tool import ( "bytes" + "path" + // "flag" "fmt" "go/ast" @@ -155,14 +157,16 @@ type Block struct { // File is a wrapper for the state of a file used in the parser. // The basic parse tree walker is a method of this type. type File struct { - fset *token.FileSet - name string // Name of file. - astFile *ast.File - blocks []Block - content []byte - edit *Buffer // QINIU - varVar string // QINIU - mode string // QINIU + fset *token.FileSet + name string // Name of file. + astFile *ast.File + blocks []Block + content []byte + edit *Buffer // QINIU + varVar string // QINIU + mode string // QINIU + importpathFileName string // QINIU, importpath + filename + random string // QINIU, random == tmp dir name } // findText finds text in the original source, starting at pos. @@ -304,7 +308,7 @@ func (f *File) Visit(node ast.Node) ast.Visitor { // 1. add cover variables into the original file // 2. return the cover variables declarations as plain string // original dec: func annotate(name string) { -func Annotate(name string, mode string, varVar string, globalCoverVarImportPath string) string { +func Annotate(name string, mode string, varVar string, importpathFilename string, globalCoverVarImportPath string) string { // QINIU switch mode { case "set": @@ -313,6 +317,8 @@ func Annotate(name string, mode string, varVar string, globalCoverVarImportPath counterStmt = incCounterStmt case "atomic": counterStmt = atomicCounterStmt + case "watch": + counterStmt = watchCounterStmt default: counterStmt = incCounterStmt } @@ -328,13 +334,15 @@ func Annotate(name string, mode string, varVar string, globalCoverVarImportPath } file := &File{ - fset: fset, - name: name, - content: content, - edit: NewBuffer(content), // QINIU - astFile: parsedFile, - varVar: varVar, - mode: mode, + fset: fset, + name: name, + content: content, + edit: NewBuffer(content), // QINIU + astFile: parsedFile, + varVar: varVar, // QINIU + mode: mode, // QINIU + importpathFileName: importpathFilename, // QINIU + random: path.Base(globalCoverVarImportPath), // QINIU } ast.Walk(file, file.astFile) @@ -409,6 +417,11 @@ func atomicCounterStmt(f *File, counter string) string { return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter) } +// watchCounterStmt returns the expression: __count[23]++;UploadCoverChangeEvent(blockname, pos[:], index) +func watchCounterStmt(f *File, counter string) string { + return fmt.Sprintf("%s++; UploadCoverChangeEvent_%v(%s.BlockName, %s.Pos[:], %v)", counter, f.random, f.varVar, f.varVar, len(f.blocks)) +} + // QINIU // newCounter creates a new counter expression of the appropriate form. func (f *File) newCounter(start, end token.Pos, numStmt int) string { @@ -689,8 +702,12 @@ func (f *File) addVariables(w io.Writer) { fmt.Fprintf(w, "\tCount [%d]uint32\n", len(f.blocks)) fmt.Fprintf(w, "\tPos [3 * %d]uint32\n", len(f.blocks)) fmt.Fprintf(w, "\tNumStmt [%d]uint16\n", len(f.blocks)) + fmt.Fprintf(w, "\tBlockName string\n") // QINIU fmt.Fprintf(w, "} {\n") + // 写入 BlockName 初始化 + fmt.Fprintf(w, "\tBlockName: \"%v\",\n", f.importpathFileName) + // Initialize the position array field. fmt.Fprintf(w, "\tPos: [3 * %d]uint32{\n", len(f.blocks)) diff --git a/pkg/cover/template.go b/pkg/cover/template.go index bc39011..7e39a7c 100644 --- a/pkg/cover/template.go +++ b/pkg/cover/template.go @@ -19,3 +19,8 @@ var coverMainTmpl = template.Must(template.New("coverMain").Parse(coverMain)) //go:embed agent.tpl var coverMain string + +var coverWatchTmpl = template.Must(template.New("coverWatch").Parse(coverWatch)) + +//go:embed agentwatch.tpl +var coverWatch string diff --git a/pkg/server/api.go b/pkg/server/api.go index 81e49c9..8ae0850 100644 --- a/pkg/server/api.go +++ b/pkg/server/api.go @@ -12,11 +12,11 @@ import ( "k8s.io/test-infra/gopherage/pkg/cov" ) -// listServices return all service informations -func (gs *gocServer) listServices(c *gin.Context) { +// listAgents return all service informations +func (gs *gocServer) listAgents(c *gin.Context) { agents := make([]*gocCoveredAgent, 0) - gs.rpcClients.Range(func(key, value interface{}) bool { + gs.rpcAgents.Range(func(key, value interface{}) bool { agent, ok := value.(*gocCoveredAgent) if !ok { return false @@ -39,7 +39,7 @@ func (gs *gocServer) getProfiles(c *gin.Context) { mergedProfiles := make([][]*cover.Profile, 0) - gs.rpcClients.Range(func(key, value interface{}) bool { + gs.rpcAgents.Range(func(key, value interface{}) bool { agent, ok := value.(*gocCoveredAgent) if !ok { return false @@ -127,7 +127,7 @@ func (gs *gocServer) getProfiles(c *gin.Context) { // // it is async, the function will return immediately func (gs *gocServer) resetProfiles(c *gin.Context) { - gs.rpcClients.Range(func(key, value interface{}) bool { + gs.rpcAgents.Range(func(key, value interface{}) bool { agent, ok := value.(gocCoveredAgent) if !ok { return false @@ -149,3 +149,52 @@ func (gs *gocServer) resetProfiles(c *gin.Context) { return true }) } + +// watchProfileUpdate watch the profile change +// +// any profile change will be updated on this websocket connection. +func (gs *gocServer) watchProfileUpdate(c *gin.Context) { + // upgrade to websocket + ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Errorf("fail to establish websocket connection with watch client: %v", err) + c.JSON(http.StatusInternalServerError, nil) + } + + log.Infof("watch client connected") + + id := time.Now().String() + gwc := &gocWatchClient{ + ws: ws, + exitCh: make(chan int), + } + gs.watchClients.Store(id, gwc) + // send close msg and close ws connection + defer func() { + gs.watchClients.Delete(id) + ws.Close() + gwc.once.Do(func() { close(gwc.exitCh) }) + log.Infof("watch client disconnected") + }() + + // set pong handler + ws.SetReadDeadline(time.Now().Add(PongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(PongWait)) + return nil + }) + + // set ping goroutine to ping every PingWait time + go func() { + ticker := time.NewTicker(PingWait) + defer ticker.Stop() + + for range ticker.C { + if err := gs.wsping(ws, PongWait); err != nil { + break + } + } + }() + + <-gwc.exitCh +} diff --git a/pkg/server/common.go b/pkg/server/common.go index 360ebd2..5243e39 100644 --- a/pkg/server/common.go +++ b/pkg/server/common.go @@ -16,6 +16,14 @@ type ProfileReq string type ProfileRes string +func (gs *gocServer) wsping(ws *websocket.Conn, deadline time.Duration) error { + return ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(deadline)) +} + +func (gs *gocServer) wsclose(ws *websocket.Conn, deadline time.Duration) error { + return ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(deadline)) +} + type ReadWriteCloser struct { ws *websocket.Conn r io.Reader diff --git a/pkg/server/rpcstream.go b/pkg/server/rpcstream.go index dca820f..9ef4227 100644 --- a/pkg/server/rpcstream.go +++ b/pkg/server/rpcstream.go @@ -9,7 +9,6 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "github.com/qiniu/goc/v2/pkg/log" ) @@ -27,16 +26,16 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) { if hostname == "" || pid == "" || cmdline == "" { c.JSON(http.StatusBadRequest, gin.H{ - "msg": "missing some param", + "msg": "missing some params", }) return } // 计算插桩服务 id - clientId := gs.generateClientId(remoteIP.String(), hostname, cmdline, pid) + agentId := gs.generateAgentId(remoteIP.String(), hostname, cmdline, pid) // 检查 id 是否重复 - if _, ok := gs.rpcClients.Load(clientId); ok { + if _, ok := gs.rpcAgents.Load(agentId); ok { c.JSON(http.StatusBadRequest, gin.H{ - "msg": "client already exist", + "msg": "the rpc agent already exists", }) return } @@ -52,7 +51,7 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) { // upgrade to websocket ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - log.Errorf("fail to establish websocket connection with client: %v", err) + log.Errorf("fail to establish websocket connection with rpc agent: %v", err) c.JSON(http.StatusInternalServerError, nil) } @@ -63,9 +62,9 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) { gs.wsclose(ws, deadline) time.Sleep(deadline) // 从维护的 websocket 链接字典中移除 - gs.rpcClients.Delete(clientId) + gs.rpcAgents.Delete(agentId) ws.Close() - log.Infof("close connection, %v", hostname) + log.Infof("close rpc connection, %v", hostname) }() // set pong handler @@ -82,7 +81,7 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) { for range ticker.C { if err := gs.wsping(ws, PongWait); err != nil { - log.Errorf("ping to %v failed: %v", hostname, err) + log.Errorf("rpc ping to %v failed: %v", hostname, err) break } } @@ -92,8 +91,8 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) { }) }() - log.Infof("one client established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname) - // new rpc client + log.Infof("one rpc agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname) + // new rpc agent // 在这里 websocket server 作为 rpc 的客户端, // 发送 rpc 请求, // 由被插桩服务返回 rpc 应答 @@ -101,22 +100,14 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) { codec := jsonrpc.NewClientCodec(rwc) gocA.rpc = rpc.NewClientWithCodec(codec) - gocA.Id = string(clientId) - gs.rpcClients.Store(clientId, gocA) + gocA.Id = string(agentId) + gs.rpcAgents.Store(agentId, gocA) // wait for exit <-gocA.exitCh } -func (gs *gocServer) wsping(ws *websocket.Conn, deadline time.Duration) error { - return ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(deadline)) -} - -func (gs *gocServer) wsclose(ws *websocket.Conn, deadline time.Duration) error { - return ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(deadline)) -} - -// generateClientId generate id based on client's meta infomation -func (gs *gocServer) generateClientId(args ...string) gocCliendId { +// generateAgentId generate id based on agent's meta infomation +func (gs *gocServer) generateAgentId(args ...string) gocCliendId { var path string for _, arg := range args { path += arg diff --git a/pkg/server/server.go b/pkg/server/server.go index 1163bdc..b43cfa0 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -1,6 +1,7 @@ package server import ( + "net/http" "net/rpc" "sync" "time" @@ -15,8 +16,10 @@ type gocServer struct { storePath string upgrader websocket.Upgrader - rpcClients sync.Map - // mu sync.Mutex // used to protect concurrent rpc call to agent + rpcAgents sync.Map + watchAgents sync.Map + watchCh chan []byte + watchClients sync.Map } type gocCliendId string @@ -34,6 +37,13 @@ type gocCoveredAgent struct { once sync.Once `json:"-"` // 保护 close(exitCh) 只执行一次 } +// api 客户端,不是 agent +type gocWatchClient struct { + ws *websocket.Conn + exitCh chan int + once sync.Once +} + func RunGocServerUntilExit(host string) { gs := gocServer{ storePath: "", @@ -41,7 +51,11 @@ func RunGocServerUntilExit(host string) { ReadBufferSize: 4096, WriteBufferSize: 4096, HandshakeTimeout: 45 * time.Second, + CheckOrigin: func(r *http.Request) bool { + return true + }, }, + watchCh: make(chan []byte), } r := gin.Default() @@ -49,14 +63,17 @@ func RunGocServerUntilExit(host string) { { v2.GET("/cover/profile", gs.getProfiles) v2.DELETE("/cover/profile", gs.resetProfiles) - v2.GET("/services", gs.listServices) + v2.GET("/rpcagents", gs.listAgents) + v2.GET("/watchagents", nil) - v2.GET("/cover/ws/watch", nil) + v2.GET("/cover/ws/watch", gs.watchProfileUpdate) // internal use only v2.GET("/internal/ws/rpcstream", gs.serveRpcStream) - v2.GET("/internal/ws/watchstream", nil) + v2.GET("/internal/ws/watchstream", gs.serveWatchInternalStream) } + go gs.watchLoop() + r.Run(host) } diff --git a/pkg/server/watchstream.go b/pkg/server/watchstream.go index 4ff461d..7d8a58b 100644 --- a/pkg/server/watchstream.go +++ b/pkg/server/watchstream.go @@ -1,9 +1,101 @@ package server -// GocWatchCoverArg defines client -> server arg -type GocWatchCoverArg struct { +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/qiniu/goc/v2/pkg/log" +) + +func (gs *gocServer) serveWatchInternalStream(c *gin.Context) { + // 检查插桩服务上报的信息 + remoteIP, _ := c.RemoteIP() + hostname := c.Query("hostname") + pid := c.Query("pid") + cmdline := c.Query("cmdline") + + if hostname == "" || pid == "" || cmdline == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "missing some params", + }) + return + } + // 计算插桩服务 id + agentId := gs.generateAgentId(remoteIP.String(), hostname, cmdline, pid) + // 检查 id 是否重复 + if _, ok := gs.watchAgents.Load(agentId); ok { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "the watch agent already exist", + }) + return + } + // upgrade to websocket + ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Errorf("fail to establish websocket connection with watch agent: %v", err) + c.JSON(http.StatusInternalServerError, nil) + } + + // send close msg and close ws connection + defer func() { + gs.watchAgents.Delete(agentId) + ws.Close() + log.Infof("close watch connection, %v", hostname) + }() + + // set pong handler + ws.SetReadDeadline(time.Now().Add(PongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(PongWait)) + return nil + }) + + // set ping goroutine to ping every PingWait time + go func() { + ticker := time.NewTicker(PingWait) + defer ticker.Stop() + + for range ticker.C { + if err := gs.wsping(ws, PongWait); err != nil { + log.Errorf("watch ping to %v failed: %v", hostname, err) + break + } + } + }() + + log.Infof("one watch agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname) + + for { + mt, message, err := ws.ReadMessage() + if err != nil { + log.Errorf("read from %v: %v", hostname, err) + break + } + if mt == websocket.TextMessage { + // 非阻塞写 + select { + case gs.watchCh <- message: + default: + } + } + } } -// GocWatchCoverReply defines client -> server reply -type GocWatchCoverReply struct { +func (gs *gocServer) watchLoop() { + for { + msg := <-gs.watchCh + gs.watchClients.Range(func(key, value interface{}) bool { + // 这里是客户端的 ws 连接,不是 agent ws 连接 + gwc := value.(*gocWatchClient) + err := gwc.ws.WriteMessage(websocket.TextMessage, msg) + if err != nil { + gwc.ws.Close() + gwc.once.Do(func() { close(gwc.exitCh) }) + } + + return true + }) + } }