add watch mode
This commit is contained in:
parent
3e5ab72240
commit
cf44927ce9
@ -14,7 +14,7 @@ var buildCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func init() {
|
||||
buildCmd.Flags().StringVarP(&config.GocConfig.Mode, "mode", "", "count", "coverage mode: set, count, atomic")
|
||||
buildCmd.Flags().StringVarP(&config.GocConfig.Mode, "mode", "", "count", "coverage mode: set, count, atomic, watch")
|
||||
buildCmd.Flags().StringVarP(&config.GocConfig.Host, "host", "", "127.0.0.1:7777", "specify the host of the goc sever")
|
||||
rootCmd.AddCommand(buildCmd)
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ func tmpFolderName(path string) string {
|
||||
sum := sha256.Sum256([]byte(path))
|
||||
h := fmt.Sprintf("%x", sum[:6])
|
||||
|
||||
return "goc-build-" + h
|
||||
return "gocbuild" + h
|
||||
}
|
||||
|
||||
// skipCopy skip copy .git dir and irregular files
|
||||
|
@ -72,7 +72,7 @@ type PackageCover struct {
|
||||
|
||||
// FileVar holds the name of the generated coverage variables targeting the named file.
|
||||
type FileVar struct {
|
||||
File string
|
||||
File string // 这里其实不是文件名,是 importpath + filename
|
||||
Var string
|
||||
}
|
||||
|
||||
|
@ -53,11 +53,11 @@ func init() {
|
||||
rpcstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/rpcstream?%v", host, v.Encode())
|
||||
ws, _, err := dialer.Dial(rpcstreamUrl, nil)
|
||||
if err != nil {
|
||||
log.Printf("[goc][Error] fail to dial to goc server: %v", err)
|
||||
log.Printf("[goc][Error] rpc fail to dial to goc server: %v", err)
|
||||
time.Sleep(waitDelay)
|
||||
continue
|
||||
}
|
||||
log.Printf("[goc][Info] connected to goc server")
|
||||
log.Printf("[goc][Info] rpc connected to goc server")
|
||||
|
||||
rwc := &ReadWriteCloser{ws: ws}
|
||||
s := rpc.NewServer()
|
||||
@ -67,7 +67,7 @@ func init() {
|
||||
// exit rpc server, close ws connection
|
||||
ws.Close()
|
||||
time.Sleep(waitDelay)
|
||||
log.Printf("[goc][Error] connection to goc server broken", )
|
||||
log.Printf("[goc][Error] rpc connection to goc server broken", )
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
151
pkg/cover/agentwatch.tpl
Normal file
151
pkg/cover/agentwatch.tpl
Normal file
@ -0,0 +1,151 @@
|
||||
package coverdef
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
"os"
|
||||
"log"
|
||||
"strconv"
|
||||
"net/url"
|
||||
|
||||
"{{.GlobalCoverVarImportPath}}/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
watchChannel = make(chan *blockInfo, 1024)
|
||||
|
||||
watchEnabled = false
|
||||
|
||||
waitDelay time.Duration = 10 * time.Second
|
||||
host string = "{{.Host}}"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// init host
|
||||
host_env := os.Getenv("GOC_CUSTOM_HOST")
|
||||
if host_env != "" {
|
||||
host = host_env
|
||||
}
|
||||
|
||||
var dialer = websocket.DefaultDialer
|
||||
|
||||
go func() {
|
||||
for {
|
||||
// 获取进程元信息用于注册
|
||||
ps, err := getRegisterInfo()
|
||||
if err != nil {
|
||||
time.Sleep(waitDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
// 注册,直接将元信息放在 ws 地址中
|
||||
v := url.Values{}
|
||||
v.Set("hostname", ps.hostname)
|
||||
v.Set("pid", strconv.Itoa(ps.pid))
|
||||
v.Set("cmdline", ps.cmdline)
|
||||
v.Encode()
|
||||
|
||||
watchstreamUrl := fmt.Sprintf("ws://%v/v2/internal/ws/watchstream?%v", host, v.Encode())
|
||||
ws, _, err := dialer.Dial(watchstreamUrl, nil)
|
||||
if err != nil {
|
||||
log.Printf("[goc][Error] watch fail to dial to goc server: %v", err)
|
||||
time.Sleep(waitDelay)
|
||||
continue
|
||||
}
|
||||
|
||||
// 连接成功
|
||||
watchEnabled = true
|
||||
log.Printf("[goc][Info] watch connected to goc server")
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
closeFlag := false
|
||||
go func() {
|
||||
for {
|
||||
// 必须调用一下以触发 ping 的自动处理
|
||||
_, _, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
closeFlag = true
|
||||
}()
|
||||
|
||||
Loop:
|
||||
for {
|
||||
select {
|
||||
case block := <-watchChannel:
|
||||
i := block.i
|
||||
|
||||
cov := fmt.Sprintf("%s:%d.%d,%d.%d %d %d", block.name,
|
||||
block.pos[3*i+0], uint16(block.pos[3*i+2]),
|
||||
block.pos[3*i+1], uint16(block.pos[3*i+2] >> 16),
|
||||
1,
|
||||
0)
|
||||
|
||||
err = ws.WriteMessage(websocket.TextMessage, []byte(cov))
|
||||
if err != nil {
|
||||
watchEnabled = false
|
||||
log.Println("[goc][Error] push coverage failed: %v", err)
|
||||
time.Sleep(waitDelay)
|
||||
break Loop
|
||||
}
|
||||
case <-ticker.C:
|
||||
if closeFlag == true {
|
||||
break Loop
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// get process meta info for register
|
||||
type processInfo struct {
|
||||
hostname string
|
||||
pid int
|
||||
cmdline string
|
||||
}
|
||||
|
||||
func getRegisterInfo() (*processInfo, error) {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
log.Printf("[goc][Error] fail to get hostname: %v", hostname)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pid := os.Getpid()
|
||||
|
||||
cmdline := os.Args[0]
|
||||
|
||||
return &processInfo{
|
||||
hostname: hostname,
|
||||
pid: pid,
|
||||
cmdline: cmdline,
|
||||
}, nil
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type blockInfo struct {
|
||||
name string
|
||||
pos []uint32
|
||||
i int
|
||||
}
|
||||
|
||||
// UploadCoverChangeEvent_{{.Random}} is non-blocking
|
||||
func UploadCoverChangeEvent_{{.Random}}(name string, pos []uint32, i int) {
|
||||
|
||||
if watchEnabled == false {
|
||||
return
|
||||
}
|
||||
|
||||
// make sure send is non-blocking
|
||||
select {
|
||||
case watchChannel <- &blockInfo{
|
||||
name: name,
|
||||
pos: pos,
|
||||
i: i,
|
||||
}:
|
||||
default:
|
||||
}
|
||||
}
|
@ -56,6 +56,12 @@ func Inject() {
|
||||
}
|
||||
// 在工程根目录注入所有插桩变量的声明+定义
|
||||
injectGlobalCoverVarFile(allDecl)
|
||||
// 在工程根目录注入 watch agent 的定义
|
||||
if config.GocConfig.Mode == "watch" {
|
||||
log.Infof("watch mode is enabled")
|
||||
injectWatchAgentFile()
|
||||
log.Donef("watch handler injected")
|
||||
}
|
||||
// 添加自定义 websocket 依赖
|
||||
// 用户代码可能有 gorrila/websocket 的依赖,为避免依赖冲突,以及可能的 replace/vendor,
|
||||
// 这里直接注入一份完整的 gorrila/websocket 实现
|
||||
@ -81,7 +87,7 @@ func addCounters(pkg *config.Package) (*config.PackageCover, string) {
|
||||
|
||||
decl := ""
|
||||
for file, coverVar := range coverVarMap {
|
||||
decl += "\n" + tool.Annotate(filepath.Join(getPkgTmpDir(pkg.Dir), file), mode, coverVar.Var, gobalCoverVarImportPath) + "\n"
|
||||
decl += "\n" + tool.Annotate(filepath.Join(getPkgTmpDir(pkg.Dir), file), mode, coverVar.Var, coverVar.File, gobalCoverVarImportPath) + "\n"
|
||||
}
|
||||
|
||||
return &config.PackageCover{
|
||||
@ -120,6 +126,7 @@ func getPkgTmpDir(pkgDir string) string {
|
||||
// 使用 bridge.go 文件是为了避免插桩逻辑中的变量名污染 main 包
|
||||
func injectGocAgent(where string, covers []*config.PackageCover) {
|
||||
injectPkgName := "goc-cover-agent-apis-auto-generated-11111-22222-package"
|
||||
injectBridgeName := "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go"
|
||||
wherePkg := filepath.Join(where, injectPkgName)
|
||||
err := os.MkdirAll(wherePkg, os.ModePerm)
|
||||
if err != nil {
|
||||
@ -127,7 +134,7 @@ func injectGocAgent(where string, covers []*config.PackageCover) {
|
||||
}
|
||||
|
||||
// create bridge file
|
||||
whereBridge := filepath.Join(where, "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go")
|
||||
whereBridge := filepath.Join(where, injectBridgeName)
|
||||
f, err := os.Create(whereBridge)
|
||||
if err != nil {
|
||||
log.Fatalf("fail to create cover bridge file in temporary project: %v", err)
|
||||
@ -153,6 +160,12 @@ func injectGocAgent(where string, covers []*config.PackageCover) {
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var _coverMode string
|
||||
if config.GocConfig.Mode == "watch" {
|
||||
_coverMode = "cover"
|
||||
} else {
|
||||
_coverMode = config.GocConfig.Mode
|
||||
}
|
||||
tmplData := struct {
|
||||
Covers []*config.PackageCover
|
||||
GlobalCoverVarImportPath string
|
||||
@ -164,7 +177,7 @@ func injectGocAgent(where string, covers []*config.PackageCover) {
|
||||
GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath,
|
||||
Package: injectPkgName,
|
||||
Host: config.GocConfig.Host,
|
||||
Mode: config.GocConfig.Mode,
|
||||
Mode: _coverMode,
|
||||
}
|
||||
|
||||
if err := coverMainTmpl.Execute(f, tmplData); err != nil {
|
||||
@ -196,3 +209,27 @@ func injectGlobalCoverVarFile(decl string) {
|
||||
log.Fatalf("fail to write to global cover definition file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func injectWatchAgentFile() {
|
||||
globalCoverVarPackage := path.Base(config.GocConfig.GlobalCoverVarImportPath)
|
||||
globalCoverDef := filepath.Join(config.GocConfig.TmpModProjectDir, globalCoverVarPackage)
|
||||
|
||||
f, err := os.Create(filepath.Join(globalCoverDef, "watchagent.go"))
|
||||
if err != nil {
|
||||
log.Fatalf("fail to create watchagent file: %v", err)
|
||||
}
|
||||
|
||||
tmplData := struct {
|
||||
Random string
|
||||
Host string
|
||||
GlobalCoverVarImportPath string
|
||||
}{
|
||||
Random: filepath.Base(config.GocConfig.TmpModProjectDir),
|
||||
Host: config.GocConfig.Host,
|
||||
GlobalCoverVarImportPath: config.GocConfig.GlobalCoverVarImportPath,
|
||||
}
|
||||
|
||||
if err := coverWatchTmpl.Execute(f, tmplData); err != nil {
|
||||
log.Fatalf("fail to generate watchagent in temporary project: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -6,6 +6,8 @@ package tool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"path"
|
||||
|
||||
// "flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
@ -155,14 +157,16 @@ type Block struct {
|
||||
// File is a wrapper for the state of a file used in the parser.
|
||||
// The basic parse tree walker is a method of this type.
|
||||
type File struct {
|
||||
fset *token.FileSet
|
||||
name string // Name of file.
|
||||
astFile *ast.File
|
||||
blocks []Block
|
||||
content []byte
|
||||
edit *Buffer // QINIU
|
||||
varVar string // QINIU
|
||||
mode string // QINIU
|
||||
fset *token.FileSet
|
||||
name string // Name of file.
|
||||
astFile *ast.File
|
||||
blocks []Block
|
||||
content []byte
|
||||
edit *Buffer // QINIU
|
||||
varVar string // QINIU
|
||||
mode string // QINIU
|
||||
importpathFileName string // QINIU, importpath + filename
|
||||
random string // QINIU, random == tmp dir name
|
||||
}
|
||||
|
||||
// findText finds text in the original source, starting at pos.
|
||||
@ -304,7 +308,7 @@ func (f *File) Visit(node ast.Node) ast.Visitor {
|
||||
// 1. add cover variables into the original file
|
||||
// 2. return the cover variables declarations as plain string
|
||||
// original dec: func annotate(name string) {
|
||||
func Annotate(name string, mode string, varVar string, globalCoverVarImportPath string) string {
|
||||
func Annotate(name string, mode string, varVar string, importpathFilename string, globalCoverVarImportPath string) string {
|
||||
// QINIU
|
||||
switch mode {
|
||||
case "set":
|
||||
@ -313,6 +317,8 @@ func Annotate(name string, mode string, varVar string, globalCoverVarImportPath
|
||||
counterStmt = incCounterStmt
|
||||
case "atomic":
|
||||
counterStmt = atomicCounterStmt
|
||||
case "watch":
|
||||
counterStmt = watchCounterStmt
|
||||
default:
|
||||
counterStmt = incCounterStmt
|
||||
}
|
||||
@ -328,13 +334,15 @@ func Annotate(name string, mode string, varVar string, globalCoverVarImportPath
|
||||
}
|
||||
|
||||
file := &File{
|
||||
fset: fset,
|
||||
name: name,
|
||||
content: content,
|
||||
edit: NewBuffer(content), // QINIU
|
||||
astFile: parsedFile,
|
||||
varVar: varVar,
|
||||
mode: mode,
|
||||
fset: fset,
|
||||
name: name,
|
||||
content: content,
|
||||
edit: NewBuffer(content), // QINIU
|
||||
astFile: parsedFile,
|
||||
varVar: varVar, // QINIU
|
||||
mode: mode, // QINIU
|
||||
importpathFileName: importpathFilename, // QINIU
|
||||
random: path.Base(globalCoverVarImportPath), // QINIU
|
||||
}
|
||||
|
||||
ast.Walk(file, file.astFile)
|
||||
@ -409,6 +417,11 @@ func atomicCounterStmt(f *File, counter string) string {
|
||||
return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter)
|
||||
}
|
||||
|
||||
// watchCounterStmt returns the expression: __count[23]++;UploadCoverChangeEvent(blockname, pos[:], index)
|
||||
func watchCounterStmt(f *File, counter string) string {
|
||||
return fmt.Sprintf("%s++; UploadCoverChangeEvent_%v(%s.BlockName, %s.Pos[:], %v)", counter, f.random, f.varVar, f.varVar, len(f.blocks))
|
||||
}
|
||||
|
||||
// QINIU
|
||||
// newCounter creates a new counter expression of the appropriate form.
|
||||
func (f *File) newCounter(start, end token.Pos, numStmt int) string {
|
||||
@ -689,8 +702,12 @@ func (f *File) addVariables(w io.Writer) {
|
||||
fmt.Fprintf(w, "\tCount [%d]uint32\n", len(f.blocks))
|
||||
fmt.Fprintf(w, "\tPos [3 * %d]uint32\n", len(f.blocks))
|
||||
fmt.Fprintf(w, "\tNumStmt [%d]uint16\n", len(f.blocks))
|
||||
fmt.Fprintf(w, "\tBlockName string\n") // QINIU
|
||||
fmt.Fprintf(w, "} {\n")
|
||||
|
||||
// 写入 BlockName 初始化
|
||||
fmt.Fprintf(w, "\tBlockName: \"%v\",\n", f.importpathFileName)
|
||||
|
||||
// Initialize the position array field.
|
||||
fmt.Fprintf(w, "\tPos: [3 * %d]uint32{\n", len(f.blocks))
|
||||
|
||||
|
@ -19,3 +19,8 @@ var coverMainTmpl = template.Must(template.New("coverMain").Parse(coverMain))
|
||||
|
||||
//go:embed agent.tpl
|
||||
var coverMain string
|
||||
|
||||
var coverWatchTmpl = template.Must(template.New("coverWatch").Parse(coverWatch))
|
||||
|
||||
//go:embed agentwatch.tpl
|
||||
var coverWatch string
|
||||
|
@ -12,11 +12,11 @@ import (
|
||||
"k8s.io/test-infra/gopherage/pkg/cov"
|
||||
)
|
||||
|
||||
// listServices return all service informations
|
||||
func (gs *gocServer) listServices(c *gin.Context) {
|
||||
// listAgents return all service informations
|
||||
func (gs *gocServer) listAgents(c *gin.Context) {
|
||||
agents := make([]*gocCoveredAgent, 0)
|
||||
|
||||
gs.rpcClients.Range(func(key, value interface{}) bool {
|
||||
gs.rpcAgents.Range(func(key, value interface{}) bool {
|
||||
agent, ok := value.(*gocCoveredAgent)
|
||||
if !ok {
|
||||
return false
|
||||
@ -39,7 +39,7 @@ func (gs *gocServer) getProfiles(c *gin.Context) {
|
||||
|
||||
mergedProfiles := make([][]*cover.Profile, 0)
|
||||
|
||||
gs.rpcClients.Range(func(key, value interface{}) bool {
|
||||
gs.rpcAgents.Range(func(key, value interface{}) bool {
|
||||
agent, ok := value.(*gocCoveredAgent)
|
||||
if !ok {
|
||||
return false
|
||||
@ -127,7 +127,7 @@ func (gs *gocServer) getProfiles(c *gin.Context) {
|
||||
//
|
||||
// it is async, the function will return immediately
|
||||
func (gs *gocServer) resetProfiles(c *gin.Context) {
|
||||
gs.rpcClients.Range(func(key, value interface{}) bool {
|
||||
gs.rpcAgents.Range(func(key, value interface{}) bool {
|
||||
agent, ok := value.(gocCoveredAgent)
|
||||
if !ok {
|
||||
return false
|
||||
@ -149,3 +149,52 @@ func (gs *gocServer) resetProfiles(c *gin.Context) {
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// watchProfileUpdate watch the profile change
|
||||
//
|
||||
// any profile change will be updated on this websocket connection.
|
||||
func (gs *gocServer) watchProfileUpdate(c *gin.Context) {
|
||||
// upgrade to websocket
|
||||
ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Errorf("fail to establish websocket connection with watch client: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, nil)
|
||||
}
|
||||
|
||||
log.Infof("watch client connected")
|
||||
|
||||
id := time.Now().String()
|
||||
gwc := &gocWatchClient{
|
||||
ws: ws,
|
||||
exitCh: make(chan int),
|
||||
}
|
||||
gs.watchClients.Store(id, gwc)
|
||||
// send close msg and close ws connection
|
||||
defer func() {
|
||||
gs.watchClients.Delete(id)
|
||||
ws.Close()
|
||||
gwc.once.Do(func() { close(gwc.exitCh) })
|
||||
log.Infof("watch client disconnected")
|
||||
}()
|
||||
|
||||
// set pong handler
|
||||
ws.SetReadDeadline(time.Now().Add(PongWait))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(PongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
// set ping goroutine to ping every PingWait time
|
||||
go func() {
|
||||
ticker := time.NewTicker(PingWait)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := gs.wsping(ws, PongWait); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
<-gwc.exitCh
|
||||
}
|
||||
|
@ -16,6 +16,14 @@ type ProfileReq string
|
||||
|
||||
type ProfileRes string
|
||||
|
||||
func (gs *gocServer) wsping(ws *websocket.Conn, deadline time.Duration) error {
|
||||
return ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
func (gs *gocServer) wsclose(ws *websocket.Conn, deadline time.Duration) error {
|
||||
return ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
type ReadWriteCloser struct {
|
||||
ws *websocket.Conn
|
||||
r io.Reader
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/qiniu/goc/v2/pkg/log"
|
||||
)
|
||||
|
||||
@ -27,16 +26,16 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
|
||||
|
||||
if hostname == "" || pid == "" || cmdline == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"msg": "missing some param",
|
||||
"msg": "missing some params",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 计算插桩服务 id
|
||||
clientId := gs.generateClientId(remoteIP.String(), hostname, cmdline, pid)
|
||||
agentId := gs.generateAgentId(remoteIP.String(), hostname, cmdline, pid)
|
||||
// 检查 id 是否重复
|
||||
if _, ok := gs.rpcClients.Load(clientId); ok {
|
||||
if _, ok := gs.rpcAgents.Load(agentId); ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"msg": "client already exist",
|
||||
"msg": "the rpc agent already exists",
|
||||
})
|
||||
return
|
||||
}
|
||||
@ -52,7 +51,7 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
|
||||
// upgrade to websocket
|
||||
ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Errorf("fail to establish websocket connection with client: %v", err)
|
||||
log.Errorf("fail to establish websocket connection with rpc agent: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, nil)
|
||||
}
|
||||
|
||||
@ -63,9 +62,9 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
|
||||
gs.wsclose(ws, deadline)
|
||||
time.Sleep(deadline)
|
||||
// 从维护的 websocket 链接字典中移除
|
||||
gs.rpcClients.Delete(clientId)
|
||||
gs.rpcAgents.Delete(agentId)
|
||||
ws.Close()
|
||||
log.Infof("close connection, %v", hostname)
|
||||
log.Infof("close rpc connection, %v", hostname)
|
||||
}()
|
||||
|
||||
// set pong handler
|
||||
@ -82,7 +81,7 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
|
||||
|
||||
for range ticker.C {
|
||||
if err := gs.wsping(ws, PongWait); err != nil {
|
||||
log.Errorf("ping to %v failed: %v", hostname, err)
|
||||
log.Errorf("rpc ping to %v failed: %v", hostname, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
@ -92,8 +91,8 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
|
||||
})
|
||||
}()
|
||||
|
||||
log.Infof("one client established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname)
|
||||
// new rpc client
|
||||
log.Infof("one rpc agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname)
|
||||
// new rpc agent
|
||||
// 在这里 websocket server 作为 rpc 的客户端,
|
||||
// 发送 rpc 请求,
|
||||
// 由被插桩服务返回 rpc 应答
|
||||
@ -101,22 +100,14 @@ func (gs *gocServer) serveRpcStream(c *gin.Context) {
|
||||
codec := jsonrpc.NewClientCodec(rwc)
|
||||
|
||||
gocA.rpc = rpc.NewClientWithCodec(codec)
|
||||
gocA.Id = string(clientId)
|
||||
gs.rpcClients.Store(clientId, gocA)
|
||||
gocA.Id = string(agentId)
|
||||
gs.rpcAgents.Store(agentId, gocA)
|
||||
// wait for exit
|
||||
<-gocA.exitCh
|
||||
}
|
||||
|
||||
func (gs *gocServer) wsping(ws *websocket.Conn, deadline time.Duration) error {
|
||||
return ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
func (gs *gocServer) wsclose(ws *websocket.Conn, deadline time.Duration) error {
|
||||
return ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(deadline))
|
||||
}
|
||||
|
||||
// generateClientId generate id based on client's meta infomation
|
||||
func (gs *gocServer) generateClientId(args ...string) gocCliendId {
|
||||
// generateAgentId generate id based on agent's meta infomation
|
||||
func (gs *gocServer) generateAgentId(args ...string) gocCliendId {
|
||||
var path string
|
||||
for _, arg := range args {
|
||||
path += arg
|
||||
|
@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
"time"
|
||||
@ -15,8 +16,10 @@ type gocServer struct {
|
||||
storePath string
|
||||
upgrader websocket.Upgrader
|
||||
|
||||
rpcClients sync.Map
|
||||
// mu sync.Mutex // used to protect concurrent rpc call to agent
|
||||
rpcAgents sync.Map
|
||||
watchAgents sync.Map
|
||||
watchCh chan []byte
|
||||
watchClients sync.Map
|
||||
}
|
||||
|
||||
type gocCliendId string
|
||||
@ -34,6 +37,13 @@ type gocCoveredAgent struct {
|
||||
once sync.Once `json:"-"` // 保护 close(exitCh) 只执行一次
|
||||
}
|
||||
|
||||
// api 客户端,不是 agent
|
||||
type gocWatchClient struct {
|
||||
ws *websocket.Conn
|
||||
exitCh chan int
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func RunGocServerUntilExit(host string) {
|
||||
gs := gocServer{
|
||||
storePath: "",
|
||||
@ -41,7 +51,11 @@ func RunGocServerUntilExit(host string) {
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
watchCh: make(chan []byte),
|
||||
}
|
||||
|
||||
r := gin.Default()
|
||||
@ -49,14 +63,17 @@ func RunGocServerUntilExit(host string) {
|
||||
{
|
||||
v2.GET("/cover/profile", gs.getProfiles)
|
||||
v2.DELETE("/cover/profile", gs.resetProfiles)
|
||||
v2.GET("/services", gs.listServices)
|
||||
v2.GET("/rpcagents", gs.listAgents)
|
||||
v2.GET("/watchagents", nil)
|
||||
|
||||
v2.GET("/cover/ws/watch", nil)
|
||||
v2.GET("/cover/ws/watch", gs.watchProfileUpdate)
|
||||
|
||||
// internal use only
|
||||
v2.GET("/internal/ws/rpcstream", gs.serveRpcStream)
|
||||
v2.GET("/internal/ws/watchstream", nil)
|
||||
v2.GET("/internal/ws/watchstream", gs.serveWatchInternalStream)
|
||||
}
|
||||
|
||||
go gs.watchLoop()
|
||||
|
||||
r.Run(host)
|
||||
}
|
||||
|
@ -1,9 +1,101 @@
|
||||
package server
|
||||
|
||||
// GocWatchCoverArg defines client -> server arg
|
||||
type GocWatchCoverArg struct {
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/qiniu/goc/v2/pkg/log"
|
||||
)
|
||||
|
||||
func (gs *gocServer) serveWatchInternalStream(c *gin.Context) {
|
||||
// 检查插桩服务上报的信息
|
||||
remoteIP, _ := c.RemoteIP()
|
||||
hostname := c.Query("hostname")
|
||||
pid := c.Query("pid")
|
||||
cmdline := c.Query("cmdline")
|
||||
|
||||
if hostname == "" || pid == "" || cmdline == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"msg": "missing some params",
|
||||
})
|
||||
return
|
||||
}
|
||||
// 计算插桩服务 id
|
||||
agentId := gs.generateAgentId(remoteIP.String(), hostname, cmdline, pid)
|
||||
// 检查 id 是否重复
|
||||
if _, ok := gs.watchAgents.Load(agentId); ok {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"msg": "the watch agent already exist",
|
||||
})
|
||||
return
|
||||
}
|
||||
// upgrade to websocket
|
||||
ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Errorf("fail to establish websocket connection with watch agent: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, nil)
|
||||
}
|
||||
|
||||
// send close msg and close ws connection
|
||||
defer func() {
|
||||
gs.watchAgents.Delete(agentId)
|
||||
ws.Close()
|
||||
log.Infof("close watch connection, %v", hostname)
|
||||
}()
|
||||
|
||||
// set pong handler
|
||||
ws.SetReadDeadline(time.Now().Add(PongWait))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(PongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
// set ping goroutine to ping every PingWait time
|
||||
go func() {
|
||||
ticker := time.NewTicker(PingWait)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := gs.wsping(ws, PongWait); err != nil {
|
||||
log.Errorf("watch ping to %v failed: %v", hostname, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("one watch agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), cmdline, pid, hostname)
|
||||
|
||||
for {
|
||||
mt, message, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
log.Errorf("read from %v: %v", hostname, err)
|
||||
break
|
||||
}
|
||||
if mt == websocket.TextMessage {
|
||||
// 非阻塞写
|
||||
select {
|
||||
case gs.watchCh <- message:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GocWatchCoverReply defines client -> server reply
|
||||
type GocWatchCoverReply struct {
|
||||
func (gs *gocServer) watchLoop() {
|
||||
for {
|
||||
msg := <-gs.watchCh
|
||||
gs.watchClients.Range(func(key, value interface{}) bool {
|
||||
// 这里是客户端的 ws 连接,不是 agent ws 连接
|
||||
gwc := value.(*gocWatchClient)
|
||||
err := gwc.ws.WriteMessage(websocket.TextMessage, msg)
|
||||
if err != nil {
|
||||
gwc.ws.Close()
|
||||
gwc.once.Do(func() { close(gwc.exitCh) })
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user