add watch mode

This commit is contained in:
lyyyuna 2021-06-24 15:22:24 +08:00
parent 3e5ab72240
commit cf44927ce9
13 changed files with 429 additions and 62 deletions

View File

@ -14,7 +14,7 @@ var buildCmd = &cobra.Command{
} }
func init() { func init() {
buildCmd.Flags().StringVarP(&config.GocConfig.Mode, "mode", "", "count", "coverage mode: set, count, atomic") buildCmd.Flags().StringVarP(&config.GocConfig.Mode, "mode", "", "count", "coverage mode: set, count, atomic, watch")
buildCmd.Flags().StringVarP(&config.GocConfig.Host, "host", "", "127.0.0.1:7777", "specify the host of the goc sever") buildCmd.Flags().StringVarP(&config.GocConfig.Host, "host", "", "127.0.0.1:7777", "specify the host of the goc sever")
rootCmd.AddCommand(buildCmd) rootCmd.AddCommand(buildCmd)
} }

View File

@ -48,7 +48,7 @@ func tmpFolderName(path string) string {
sum := sha256.Sum256([]byte(path)) sum := sha256.Sum256([]byte(path))
h := fmt.Sprintf("%x", sum[:6]) h := fmt.Sprintf("%x", sum[:6])
return "goc-build-" + h return "gocbuild" + h
} }
// skipCopy skip copy .git dir and irregular files // skipCopy skip copy .git dir and irregular files

View File

@ -72,7 +72,7 @@ type PackageCover struct {
// FileVar holds the name of the generated coverage variables targeting the named file. // FileVar holds the name of the generated coverage variables targeting the named file.
type FileVar struct { type FileVar struct {
File string File string // 这里其实不是文件名,是 importpath + filename
Var string Var string
} }

View File

@ -53,11 +53,11 @@ func init() {
rpcstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/rpcstream?%v", host, v.Encode()) rpcstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/rpcstream?%v", host, v.Encode())
ws, _, err := dialer.Dial(rpcstreamUrl, nil) ws, _, err := dialer.Dial(rpcstreamUrl, nil)
if err != nil { if err != nil {
log.Printf("[goc][Error] fail to dial to goc server: %v", err) log.Printf("[goc][Error] rpc fail to dial to goc server: %v", err)
time.Sleep(waitDelay) time.Sleep(waitDelay)
continue continue
} }
log.Printf("[goc][Info] connected to goc server") log.Printf("[goc][Info] rpc connected to goc server")
rwc := &ReadWriteCloser{ws: ws} rwc := &ReadWriteCloser{ws: ws}
s := rpc.NewServer() s := rpc.NewServer()
@ -67,7 +67,7 @@ func init() {
// exit rpc server, close ws connection // exit rpc server, close ws connection
ws.Close() ws.Close()
time.Sleep(waitDelay) time.Sleep(waitDelay)
log.Printf("[goc][Error] connection to goc server broken", ) log.Printf("[goc][Error] rpc connection to goc server broken", )
} }
}() }()
} }

151
pkg/cover/agentwatch.tpl Normal file
View File

@ -0,0 +1,151 @@
package coverdef
import (
"fmt"
"time"
"os"
"log"
"strconv"
"net/url"
"{{.GlobalCoverVarImportPath}}/websocket"
)
var (
watchChannel = make(chan *blockInfo, 1024)
watchEnabled = false
waitDelay time.Duration = 10 * time.Second
host string = "{{.Host}}"
)
func init() {
// init host
host_env := os.Getenv("GOC_CUSTOM_HOST")
if host_env != "" {
host = host_env
}
var dialer = websocket.DefaultDialer
go func() {
for {
// 获取进程元信息用于注册
ps, err := getRegisterInfo()
if err != nil {
time.Sleep(waitDelay)
continue
}
// 注册,直接将元信息放在 ws 地址中
v := url.Values{}
v.Set("hostname", ps.hostname)
v.Set("pid", strconv.Itoa(ps.pid))
v.Set("cmdline", ps.cmdline)
v.Encode()
watchstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/watchstream?%v", host, v.Encode())
ws, _, err := dialer.Dial(watchstreamUrl, nil)
if err != nil {
log.Printf("[goc][Error] watch fail to dial to goc server: %v", err)
time.Sleep(waitDelay)
continue
}
// 连接成功
watchEnabled = true
log.Printf("[goc][Info] watch connected to goc server")
ticker := time.NewTicker(time.Second)
closeFlag := false
go func() {
for {
// 必须调用一下以触发 ping 的自动处理
_, _, err := ws.ReadMessage()
if err != nil {
break
}
}
closeFlag = true
}()
Loop:
for {
select {
case block := <-watchChannel:
i := block.i
cov := fmt.Sprintf("%s:%d.%d,%d.%d %d %d", block.name,
block.pos[3*i+0], uint16(block.pos[3*i+2]),
block.pos[3*i+1], uint16(block.pos[3*i+2] >> 16),
1,
0)
err = ws.WriteMessage(websocket.TextMessage, []byte(cov))
if err != nil {
watchEnabled = false
log.Println("[goc][Error] push coverage failed: %v", err)
time.Sleep(waitDelay)
break Loop
}
case <-ticker.C:
if closeFlag == true {
break Loop
}
}
}
}
}()
}
// get process meta info for register
type processInfo struct {
hostname string
pid int
cmdline string
}
func getRegisterInfo() (*processInfo, error) {
hostname, err := os.Hostname()
if err != nil {
log.Printf("[goc][Error] fail to get hostname: %v", hostname)
return nil, err
}
pid := os.Getpid()
cmdline := os.Args[0]
return &processInfo{
hostname: hostname,
pid: pid,
cmdline: cmdline,
}, nil
}
//
type blockInfo struct {
name string
pos []uint32
i int
}
// UploadCoverChangeEvent_{{.Random}} is non-blocking
func UploadCoverChangeEvent_{{.Random}}(name string, pos []uint32, i int) {
if watchEnabled == false {
return
}
// make sure send is non-blocking
select {
case watchChannel <- &blockInfo{
name: name,
pos: pos,
i: i,
}:
default:
}
}

View File

@ -56,6 +56,12 @@ func Inject() {
} }
// 在工程根目录注入所有插桩变量的声明+定义 // 在工程根目录注入所有插桩变量的声明+定义
injectGlobalCoverVarFile(allDecl) injectGlobalCoverVarFile(allDecl)
// 在工程根目录注入 watch agent 的定义
if config.GocConfig.Mode == "watch" {
log.Infof("watch mode is enabled")
injectWatchAgentFile()
log.Donef("watch handler injected")
}
// 添加自定义 websocket 依赖 // 添加自定义 websocket 依赖
// 用户代码可能有 gorrila/websocket 的依赖,为避免依赖冲突,以及可能的 replace/vendor // 用户代码可能有 gorrila/websocket 的依赖,为避免依赖冲突,以及可能的 replace/vendor
// 这里直接注入一份完整的 gorrila/websocket 实现 // 这里直接注入一份完整的 gorrila/websocket 实现
@ -81,7 +87,7 @@ func addCounters(pkg *config.Package) (*config.PackageCover, string) {
decl := "" decl := ""
for file, coverVar := range coverVarMap { for file, coverVar := range coverVarMap {
decl += "\n" + tool.Annotate(filepath.Join(getPkgTmpDir(pkg.Dir), file), mode, coverVar.Var, gobalCoverVarImportPath) + "\n" decl += "\n" + tool.Annotate(filepath.Join(getPkgTmpDir(pkg.Dir), file), mode, coverVar.Var, coverVar.File, gobalCoverVarImportPath) + "\n"
} }
return &config.PackageCover{ return &config.PackageCover{
@ -120,6 +126,7 @@ func getPkgTmpDir(pkgDir string) string {
// 使用 bridge.go 文件是为了避免插桩逻辑中的变量名污染 main 包 // 使用 bridge.go 文件是为了避免插桩逻辑中的变量名污染 main 包
func injectGocAgent(where string, covers []*config.PackageCover) { func injectGocAgent(where string, covers []*config.PackageCover) {
injectPkgName := "goc-cover-agent-apis-auto-generated-11111-22222-package" injectPkgName := "goc-cover-agent-apis-auto-generated-11111-22222-package"
injectBridgeName := "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go"
wherePkg := filepath.Join(where, injectPkgName) wherePkg := filepath.Join(where, injectPkgName)
err := os.MkdirAll(wherePkg, os.ModePerm) err := os.MkdirAll(wherePkg, os.ModePerm)
if err != nil { if err != nil {
@ -127,7 +134,7 @@ func injectGocAgent(where string, covers []*config.PackageCover) {
} }
// create bridge file // create bridge file
whereBridge := filepath.Join(where, "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go") whereBridge := filepath.Join(where, injectBridgeName)
f, err := os.Create(whereBridge) f, err := os.Create(whereBridge)
if err != nil { if err != nil {
log.Fatalf("fail to create cover bridge file in temporary project: %v", err) log.Fatalf("fail to create cover bridge file in temporary project: %v", err)
@ -153,6 +160,12 @@ func injectGocAgent(where string, covers []*config.PackageCover) {
} }
defer f.Close() defer f.Close()
var _coverMode string
if config.GocConfig.Mode == "watch" {
_coverMode = "cover"
} else {
_coverMode = config.GocConfig.Mode
}
tmplData := struct { tmplData := struct {
Covers []*config.PackageCover Covers []*config.PackageCover
GlobalCoverVarImportPath string GlobalCoverVarImportPath string
@ -164,7 +177,7 @@ func injectGocAgent(where string, covers []*config.PackageCover) {
GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath, GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath,
Package: injectPkgName, Package: injectPkgName,
Host: config.GocConfig.Host, Host: config.GocConfig.Host,
Mode: config.GocConfig.Mode, Mode: _coverMode,
} }
if err := coverMainTmpl.Execute(f, tmplData); err != nil { if err := coverMainTmpl.Execute(f, tmplData); err != nil {
@ -196,3 +209,27 @@ func injectGlobalCoverVarFile(decl string) {
log.Fatalf("fail to write to global cover definition file: %v", err) log.Fatalf("fail to write to global cover definition file: %v", err)
} }
} }
func injectWatchAgentFile() {
globalCoverVarPackage := path.Base(config.GocConfig.GlobalCoverVarImportPath)
globalCoverDef := filepath.Join(config.GocConfig.TmpModProjectDir, globalCoverVarPackage)
f, err := os.Create(filepath.Join(globalCoverDef, "watchagent.go"))
if err != nil {
log.Fatalf("fail to create watchagent file: %v", err)
}
tmplData := struct {
Random string
Host string
GlobalCoverVarImportPath string
}{
Random: filepath.Base(config.GocConfig.TmpModProjectDir),
Host: config.GocConfig.Host,
GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath,
}
if err := coverWatchTmpl.Execute(f, tmplData); err != nil {
log.Fatalf("fail to generate watchagent in temporary project: %v", err)
}
}

View File

@ -6,6 +6,8 @@ package tool
import ( import (
"bytes" "bytes"
"path"
// "flag" // "flag"
"fmt" "fmt"
"go/ast" "go/ast"
@ -163,6 +165,8 @@ type File struct {
edit *Buffer // QINIU edit *Buffer // QINIU
varVar string // QINIU varVar string // QINIU
mode string // QINIU mode string // QINIU
importpathFileName string // QINIU, importpath + filename
random string // QINIU, random == tmp dir name
} }
// findText finds text in the original source, starting at pos. // findText finds text in the original source, starting at pos.
@ -304,7 +308,7 @@ func (f *File) Visit(node ast.Node) ast.Visitor {
// 1. add cover variables into the original file // 1. add cover variables into the original file
// 2. return the cover variables declarations as plain string // 2. return the cover variables declarations as plain string
// original dec: func annotate(name string) { // original dec: func annotate(name string) {
func Annotate(name string, mode string, varVar string, globalCoverVarImportPath string) string { func Annotate(name string, mode string, varVar string, importpathFilename string, globalCoverVarImportPath string) string {
// QINIU // QINIU
switch mode { switch mode {
case "set": case "set":
@ -313,6 +317,8 @@ func Annotate(name string, mode string, varVar string, globalCoverVarImportPath
counterStmt = incCounterStmt counterStmt = incCounterStmt
case "atomic": case "atomic":
counterStmt = atomicCounterStmt counterStmt = atomicCounterStmt
case "watch":
counterStmt = watchCounterStmt
default: default:
counterStmt = incCounterStmt counterStmt = incCounterStmt
} }
@ -333,8 +339,10 @@ func Annotate(name string, mode string, varVar string, globalCoverVarImportPath
content: content, content: content,
edit: NewBuffer(content), // QINIU edit: NewBuffer(content), // QINIU
astFile: parsedFile, astFile: parsedFile,
varVar: varVar, varVar: varVar, // QINIU
mode: mode, mode: mode, // QINIU
importpathFileName: importpathFilename, // QINIU
random: path.Base(globalCoverVarImportPath), // QINIU
} }
ast.Walk(file, file.astFile) ast.Walk(file, file.astFile)
@ -409,6 +417,11 @@ func atomicCounterStmt(f *File, counter string) string {
return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter) return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter)
} }
// watchCounterStmt returns the expression: __count[23]++;UploadCoverChangeEvent(blockname, pos[:], index)
func watchCounterStmt(f *File, counter string) string {
return fmt.Sprintf("%s++; UploadCoverChangeEvent_%v(%s.BlockName, %s.Pos[:], %v)", counter, f.random, f.varVar, f.varVar, len(f.blocks))
}
// QINIU // QINIU
// newCounter creates a new counter expression of the appropriate form. // newCounter creates a new counter expression of the appropriate form.
func (f *File) newCounter(start, end token.Pos, numStmt int) string { func (f *File) newCounter(start, end token.Pos, numStmt int) string {
@ -689,8 +702,12 @@ func (f *File) addVariables(w io.Writer) {
fmt.Fprintf(w, "\tCount [%d]uint32\n", len(f.blocks)) fmt.Fprintf(w, "\tCount [%d]uint32\n", len(f.blocks))
fmt.Fprintf(w, "\tPos [3 * %d]uint32\n", len(f.blocks)) fmt.Fprintf(w, "\tPos [3 * %d]uint32\n", len(f.blocks))
fmt.Fprintf(w, "\tNumStmt [%d]uint16\n", len(f.blocks)) fmt.Fprintf(w, "\tNumStmt [%d]uint16\n", len(f.blocks))
fmt.Fprintf(w, "\tBlockName string\n") // QINIU
fmt.Fprintf(w, "} {\n") fmt.Fprintf(w, "} {\n")
// 写入 BlockName 初始化
fmt.Fprintf(w, "\tBlockName: \"%v\",\n", f.importpathFileName)
// Initialize the position array field. // Initialize the position array field.
fmt.Fprintf(w, "\tPos: [3 * %d]uint32{\n", len(f.blocks)) fmt.Fprintf(w, "\tPos: [3 * %d]uint32{\n", len(f.blocks))

View File

@ -19,3 +19,8 @@ var coverMainTmpl = template.Must(template.New("coverMain").Parse(coverMain))
//go:embed agent.tpl //go:embed agent.tpl
var coverMain string var coverMain string
var coverWatchTmpl = template.Must(template.New("coverWatch").Parse(coverWatch))
//go:embed agentwatch.tpl
var coverWatch string

View File

@ -12,11 +12,11 @@ import (
"k8s.io/test-infra/gopherage/pkg/cov" "k8s.io/test-infra/gopherage/pkg/cov"
) )
// listServices return all service informations // listAgents return all service informations
func (gs *gocServer) listServices(c *gin.Context) { func (gs *gocServer) listAgents(c *gin.Context) {
agents := make([]*gocCoveredAgent, 0) agents := make([]*gocCoveredAgent, 0)
gs.rpcClients.Range(func(key, value interface{}) bool { gs.rpcAgents.Range(func(key, value interface{}) bool {
agent, ok := value.(*gocCoveredAgent) agent, ok := value.(*gocCoveredAgent)
if !ok { if !ok {
return false return false
@ -39,7 +39,7 @@ func (gs *gocServer) getProfiles(c *gin.Context) {
mergedProfiles := make([][]*cover.Profile, 0) mergedProfiles := make([][]*cover.Profile, 0)
gs.rpcClients.Range(func(key, value interface{}) bool { gs.rpcAgents.Range(func(key, value interface{}) bool {
agent, ok := value.(*gocCoveredAgent) agent, ok := value.(*gocCoveredAgent)
if !ok { if !ok {
return false return false
@ -127,7 +127,7 @@ func (gs *gocServer) getProfiles(c *gin.Context) {
// //
// it is async, the function will return immediately // it is async, the function will return immediately
func (gs *gocServer) resetProfiles(c *gin.Context) { func (gs *gocServer) resetProfiles(c *gin.Context) {
gs.rpcClients.Range(func(key, value interface{}) bool { gs.rpcAgents.Range(func(key, value interface{}) bool {
agent, ok := value.(gocCoveredAgent) agent, ok := value.(gocCoveredAgent)
if !ok { if !ok {
return false return false
@ -149,3 +149,52 @@ func (gs *gocServer) resetProfiles(c *gin.Context) {
return true return true
}) })
} }
// watchProfileUpdate watch the profile change
//
// any profile change will be updated on this websocket connection.
func (gs *gocServer) watchProfileUpdate(c *gin.Context) {
// upgrade to websocket
ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Errorf("fail to establish websocket connection with watch client: %v", err)
c.JSON(http.StatusInternalServerError, nil)
}
log.Infof("watch client connected")
id := time.Now().String()
gwc := &gocWatchClient{
ws: ws,
exitCh: make(chan int),
}
gs.watchClients.Store(id, gwc)
// send close msg and close ws connection
defer func() {
gs.watchClients.Delete(id)
ws.Close()
gwc.once.Do(func() { close(gwc.exitCh) })
log.Infof("watch client disconnected")
}()
// set pong handler
ws.SetReadDeadline(time.Now().Add(PongWait))
ws.SetPongHandler(func(string) error {
ws.SetReadDeadline(time.Now().Add(PongWait))
return nil
})
// set ping goroutine to ping every PingWait time
go func() {
ticker := time.NewTicker(PingWait)
defer ticker.Stop()
for range ticker.C {
if err := gs.wsping(ws, PongWait); err != nil {
break
}
}
}()
<-gwc.exitCh
}

View File

@ -16,6 +16,14 @@ type ProfileReq string
type ProfileRes string type ProfileRes string
func (gs *gocServer) wsping(ws *websocket.Conn, deadline time.Duration) error {
return ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(deadline))
}
func (gs *gocServer) wsclose(ws *websocket.Conn, deadline time.Duration) error {
return ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(deadline))
}
type ReadWriteCloser struct { type ReadWriteCloser struct {
ws *websocket.Conn ws *websocket.Conn
r io.Reader r io.Reader

View File

@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/qiniu/goc/v2/pkg/log" "github.com/qiniu/goc/v2/pkg/log"
) )
@ -27,16 +26,16 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
if hostname == "" || pid == "" || cmdline == "" { if hostname == "" || pid == "" || cmdline == "" {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"msg": "missing some param", "msg": "missing some params",
}) })
return return
} }
// 计算插桩服务 id // 计算插桩服务 id
clientId := gs.generateClientId(remoteIP.String(), hostname, cmdline, pid) agentId := gs.generateAgentId(remoteIP.String(), hostname, cmdline, pid)
// 检查 id 是否重复 // 检查 id 是否重复
if _, ok := gs.rpcClients.Load(clientId); ok { if _, ok := gs.rpcAgents.Load(agentId); ok {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"msg": "client already exist", "msg": "the rpc agent already exists",
}) })
return return
} }
@ -52,7 +51,7 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
// upgrade to websocket // upgrade to websocket
ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
log.Errorf("fail to establish websocket connection with client: %v", err) log.Errorf("fail to establish websocket connection with rpc agent: %v", err)
c.JSON(http.StatusInternalServerError, nil) c.JSON(http.StatusInternalServerError, nil)
} }
@ -63,9 +62,9 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
gs.wsclose(ws, deadline) gs.wsclose(ws, deadline)
time.Sleep(deadline) time.Sleep(deadline)
// 从维护的 websocket 链接字典中移除 // 从维护的 websocket 链接字典中移除
gs.rpcClients.Delete(clientId) gs.rpcAgents.Delete(agentId)
ws.Close() ws.Close()
log.Infof("close connection, %v", hostname) log.Infof("close rpc connection, %v", hostname)
}() }()
// set pong handler // set pong handler
@ -82,7 +81,7 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
for range ticker.C { for range ticker.C {
if err := gs.wsping(ws, PongWait); err != nil { if err := gs.wsping(ws, PongWait); err != nil {
log.Errorf("ping to %v failed: %v", hostname, err) log.Errorf("rpc ping to %v failed: %v", hostname, err)
break break
} }
} }
@ -92,8 +91,8 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
}) })
}() }()
log.Infof("one client established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname) log.Infof("one rpc agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname)
// new rpc client // new rpc agent
// 在这里 websocket server 作为 rpc 的客户端, // 在这里 websocket server 作为 rpc 的客户端,
// 发送 rpc 请求, // 发送 rpc 请求,
// 由被插桩服务返回 rpc 应答 // 由被插桩服务返回 rpc 应答
@ -101,22 +100,14 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
codec := jsonrpc.NewClientCodec(rwc) codec := jsonrpc.NewClientCodec(rwc)
gocA.rpc = rpc.NewClientWithCodec(codec) gocA.rpc = rpc.NewClientWithCodec(codec)
gocA.Id = string(clientId) gocA.Id = string(agentId)
gs.rpcClients.Store(clientId, gocA) gs.rpcAgents.Store(agentId, gocA)
// wait for exit // wait for exit
<-gocA.exitCh <-gocA.exitCh
} }
func (gs *gocServer) wsping(ws *websocket.Conn, deadline time.Duration) error { // generateAgentId generate id based on agent's meta infomation
return ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(deadline)) func (gs *gocServer) generateAgentId(args ...string) gocCliendId {
}
func (gs *gocServer) wsclose(ws *websocket.Conn, deadline time.Duration) error {
return ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(deadline))
}
// generateClientId generate id based on client's meta infomation
func (gs *gocServer) generateClientId(args ...string) gocCliendId {
var path string var path string
for _, arg := range args { for _, arg := range args {
path += arg path += arg

View File

@ -1,6 +1,7 @@
package server package server
import ( import (
"net/http"
"net/rpc" "net/rpc"
"sync" "sync"
"time" "time"
@ -15,8 +16,10 @@ type gocServer struct {
storePath string storePath string
upgrader websocket.Upgrader upgrader websocket.Upgrader
rpcClients sync.Map rpcAgents sync.Map
// mu sync.Mutex // used to protect concurrent rpc call to agent watchAgents sync.Map
watchCh chan []byte
watchClients sync.Map
} }
type gocCliendId string type gocCliendId string
@ -34,6 +37,13 @@ type gocCoveredAgent struct {
once sync.Once `json:"-"` // 保护 close(exitCh) 只执行一次 once sync.Once `json:"-"` // 保护 close(exitCh) 只执行一次
} }
// api 客户端,不是 agent
type gocWatchClient struct {
ws *websocket.Conn
exitCh chan int
once sync.Once
}
func RunGocServerUntilExit(host string) { func RunGocServerUntilExit(host string) {
gs := gocServer{ gs := gocServer{
storePath: "", storePath: "",
@ -41,7 +51,11 @@ func RunGocServerUntilExit(host string) {
ReadBufferSize: 4096, ReadBufferSize: 4096,
WriteBufferSize: 4096, WriteBufferSize: 4096,
HandshakeTimeout: 45 * time.Second, HandshakeTimeout: 45 * time.Second,
CheckOrigin: func(r *http.Request) bool {
return true
}, },
},
watchCh: make(chan []byte),
} }
r := gin.Default() r := gin.Default()
@ -49,14 +63,17 @@ func RunGocServerUntilExit(host string) {
{ {
v2.GET("/cover/profile", gs.getProfiles) v2.GET("/cover/profile", gs.getProfiles)
v2.DELETE("/cover/profile", gs.resetProfiles) v2.DELETE("/cover/profile", gs.resetProfiles)
v2.GET("/services", gs.listServices) v2.GET("/rpcagents", gs.listAgents)
v2.GET("/watchagents", nil)
v2.GET("/cover/ws/watch", nil) v2.GET("/cover/ws/watch", gs.watchProfileUpdate)
// internal use only // internal use only
v2.GET("/internal/ws/rpcstream", gs.serveRpcStream) v2.GET("/internal/ws/rpcstream", gs.serveRpcStream)
v2.GET("/internal/ws/watchstream", nil) v2.GET("/internal/ws/watchstream", gs.serveWatchInternalStream)
} }
go gs.watchLoop()
r.Run(host) r.Run(host)
} }

View File

@ -1,9 +1,101 @@
package server package server
// GocWatchCoverArg defines client -> server arg import (
type GocWatchCoverArg struct { "net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/qiniu/goc/v2/pkg/log"
)
func (gs *gocServer) serveWatchInternalStream(c *gin.Context) {
// 检查插桩服务上报的信息
remoteIP, _ := c.RemoteIP()
hostname := c.Query("hostname")
pid := c.Query("pid")
cmdline := c.Query("cmdline")
if hostname == "" || pid == "" || cmdline == "" {
c.JSON(http.StatusBadRequest, gin.H{
"msg": "missing some params",
})
return
}
// 计算插桩服务 id
agentId := gs.generateAgentId(remoteIP.String(), hostname, cmdline, pid)
// 检查 id 是否重复
if _, ok := gs.watchAgents.Load(agentId); ok {
c.JSON(http.StatusBadRequest, gin.H{
"msg": "the watch agent already exist",
})
return
}
// upgrade to websocket
ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Errorf("fail to establish websocket connection with watch agent: %v", err)
c.JSON(http.StatusInternalServerError, nil)
}
// send close msg and close ws connection
defer func() {
gs.watchAgents.Delete(agentId)
ws.Close()
log.Infof("close watch connection, %v", hostname)
}()
// set pong handler
ws.SetReadDeadline(time.Now().Add(PongWait))
ws.SetPongHandler(func(string) error {
ws.SetReadDeadline(time.Now().Add(PongWait))
return nil
})
// set ping goroutine to ping every PingWait time
go func() {
ticker := time.NewTicker(PingWait)
defer ticker.Stop()
for range ticker.C {
if err := gs.wsping(ws, PongWait); err != nil {
log.Errorf("watch ping to %v failed: %v", hostname, err)
break
}
}
}()
log.Infof("one watch agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname)
for {
mt, message, err := ws.ReadMessage()
if err != nil {
log.Errorf("read from %v: %v", hostname, err)
break
}
if mt == websocket.TextMessage {
// 非阻塞写
select {
case gs.watchCh <- message:
default:
}
}
}
} }
// GocWatchCoverReply defines client -> server reply func (gs *gocServer) watchLoop() {
type GocWatchCoverReply struct { for {
msg := <-gs.watchCh
gs.watchClients.Range(func(key, value interface{}) bool {
// 这里是客户端的 ws 连接,不是 agent ws 连接
gwc := value.(*gocWatchClient)
err := gwc.ws.WriteMessage(websocket.TextMessage, msg)
if err != nil {
gwc.ws.Close()
gwc.once.Do(func() { close(gwc.exitCh) })
}
return true
})
}
} }