diff --git a/.gitignore b/.gitignore index e6c4f65..28b6025 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ tests/e2e/tmp/* .goc.kvstore -coverage.txt \ No newline at end of file +coverage.txt +.idea +repos \ No newline at end of file diff --git a/cmd/build.go b/cmd/build.go index e4822d3..9bd4c0b 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -14,40 +14,40 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/build" - "github.com/spf13/cobra" + "github.com/ar0c/goc/v2/pkg/build" + "github.com/spf13/cobra" ) var buildCmd = &cobra.Command{ - Use: "build", - Run: buildAction, + Use: "build", + Run: buildAction, - DisableFlagParsing: true, // build 命令需要用原生 go 的方式处理 flags + DisableFlagParsing: true, // build 命令需要用原生 go 的方式处理 flags } var ( - gocmode string - gochost string + gocmode string + gochost string ) func init() { - buildCmd.Flags().StringVarP(&gocmode, "gocmode", "", "count", "coverage mode: set, count, atomic, watch") - buildCmd.Flags().StringVarP(&gochost, "gochost", "", "127.0.0.1:7777", "specify the host of the goc sever") - rootCmd.AddCommand(buildCmd) + buildCmd.Flags().StringVarP(&gocmode, "gocmode", "", "count", "coverage mode: set, count, atomic, watch") + buildCmd.Flags().StringVarP(&gochost, "gochost", "", "127.0.0.1:7777", "specify the host of the goc sever") + rootCmd.AddCommand(buildCmd) } func buildAction(cmd *cobra.Command, args []string) { - sets := build.CustomParseCmdAndArgs(cmd, args) + sets := build.CustomParseCmdAndArgs(cmd, args) - b := build.NewBuild( - build.WithHost(gochost), - build.WithMode(gocmode), - build.WithFlagSets(sets), - build.WithArgs(args), - build.WithBuild(), - build.WithDebug(globalDebug), - ) - b.Build() + b := build.NewBuild( + build.WithHost(gochost), + build.WithMode(gocmode), + build.WithFlagSets(sets), + build.WithArgs(args), + build.WithBuild(), + build.WithDebug(globalDebug), + ) + b.Build() } diff --git a/cmd/inject.go b/cmd/inject.go index c41f0e6..74c8d02 100644 --- a/cmd/inject.go +++ b/cmd/inject.go @@ -14,7 +14,7 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/build" + "github.com/ar0c/goc/v2/pkg/build" "github.com/spf13/cobra" ) diff --git a/cmd/install.go b/cmd/install.go index 5b7df5c..0db5a08 100644 --- a/cmd/install.go +++ b/cmd/install.go @@ -14,35 +14,35 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/build" - "github.com/spf13/cobra" + "github.com/ar0c/goc/v2/pkg/build" + "github.com/spf13/cobra" ) var installCmd = &cobra.Command{ - Use: "install", - Run: installAction, + Use: "install", + Run: installAction, - DisableFlagParsing: true, // install 命令需要用原生 go 的方式处理 flags + DisableFlagParsing: true, // install 命令需要用原生 go 的方式处理 flags } func init() { - installCmd.Flags().StringVarP(&gocmode, "gocmode", "", "count", "coverage mode: set, count, atomic, watch") - installCmd.Flags().StringVarP(&gochost, "gochost", "", "127.0.0.1:7777", "specify the host of the goc sever") - rootCmd.AddCommand(installCmd) + installCmd.Flags().StringVarP(&gocmode, "gocmode", "", "count", "coverage mode: set, count, atomic, watch") + installCmd.Flags().StringVarP(&gochost, "gochost", "", "127.0.0.1:7777", "specify the host of the goc sever") + rootCmd.AddCommand(installCmd) } func installAction(cmd *cobra.Command, args []string) { - sets := build.CustomParseCmdAndArgs(cmd, args) + sets := build.CustomParseCmdAndArgs(cmd, args) - b := build.NewInstall( - build.WithHost(gochost), - build.WithMode(gocmode), - build.WithFlagSets(sets), - build.WithArgs(args), - build.WithInstall(), - build.WithDebug(globalDebug), - ) - b.Install() + b := build.NewInstall( + build.WithHost(gochost), + build.WithMode(gocmode), + build.WithFlagSets(sets), + build.WithArgs(args), + build.WithInstall(), + build.WithDebug(globalDebug), + ) + b.Install() } diff --git a/cmd/merge.go b/cmd/merge.go index 47631a3..6778630 100644 --- a/cmd/merge.go +++ b/cmd/merge.go @@ -14,55 +14,55 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/spf13/cobra" - "golang.org/x/tools/cover" - "k8s.io/test-infra/gopherage/pkg/cov" - "k8s.io/test-infra/gopherage/pkg/util" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/spf13/cobra" + "golang.org/x/tools/cover" + "k8s.io/test-infra/gopherage/pkg/cov" + "k8s.io/test-infra/gopherage/pkg/util" ) var mergeCmd = &cobra.Command{ - Use: "merge [files...]", - Short: "Merge multiple coherent Go coverage files into a single file.", - Long: `Merge will merge multiple Go coverage files into a single coverage file. + Use: "merge [files...]", + Short: "Merge multiple coherent Go coverage files into a single file.", + Long: `Merge will merge multiple Go coverage files into a single coverage file. merge requires that the files are 'coherent', meaning that if they both contain references to the same paths, then the contents of those source files were identical for the binary that generated each file. `, - Run: func(cmd *cobra.Command, args []string) { - runMerge(args, outputMergeProfile) - }, + Run: func(cmd *cobra.Command, args []string) { + runMerge(args, outputMergeProfile) + }, } var outputMergeProfile string func init() { - mergeCmd.Flags().StringVarP(&outputMergeProfile, "output", "o", "mergeprofile.cov", "output file") + mergeCmd.Flags().StringVarP(&outputMergeProfile, "output", "o", "mergeprofile.cov", "output file") - rootCmd.AddCommand(mergeCmd) + rootCmd.AddCommand(mergeCmd) } func runMerge(args []string, output string) { - if len(args) == 0 { - log.Fatalf("Expected at least one coverage file.") - } + if len(args) == 0 { + log.Fatalf("Expected at least one coverage file.") + } - profiles := make([][]*cover.Profile, len(args)) - for _, path := range args { - profile, err := util.LoadProfile(path) - if err != nil { - log.Fatalf("failed to open %s: %v", path, err) - } - profiles = append(profiles, profile) - } + profiles := make([][]*cover.Profile, len(args)) + for _, path := range args { + profile, err := util.LoadProfile(path) + if err != nil { + log.Fatalf("failed to open %s: %v", path, err) + } + profiles = append(profiles, profile) + } - merged, err := cov.MergeMultipleProfiles(profiles) - if err != nil { - log.Fatalf("failed to merge files: %v", err) - } + merged, err := cov.MergeMultipleProfiles(profiles) + if err != nil { + log.Fatalf("failed to merge files: %v", err) + } - err = util.DumpProfile(output, merged) - if err != nil { - log.Fatalf("fail to dump the merged file: %v", err) - } + err = util.DumpProfile(output, merged) + if err != nil { + log.Fatalf("fail to dump the merged file: %v", err) + } } diff --git a/cmd/profile.go b/cmd/profile.go index 54f5a8d..5c82ee0 100644 --- a/cmd/profile.go +++ b/cmd/profile.go @@ -14,65 +14,65 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/client" - "github.com/spf13/cobra" - "github.com/spf13/pflag" + "github.com/ar0c/goc/v2/pkg/client" + "github.com/spf13/cobra" + "github.com/spf13/pflag" ) var profileCmd = &cobra.Command{ - Use: "profile", - Short: "Get coverage profile from service registry center", - Long: `Get code coverage profile for the services under test at runtime.`, - //Run: profile, + Use: "profile", + Short: "Get coverage profile from service registry center", + Long: `Get code coverage profile for the services under test at runtime.`, + //Run: profile, } var ( - profileHost string - profileOutput string // --output flag - profileIds []string - profileSkipPattern []string - profileExtra string - profileNeedPattern []string + profileHost string + profileOutput string // --output flag + profileIds []string + profileSkipPattern []string + profileExtra string + profileNeedPattern []string ) func init() { - add1Flags := func(f *pflag.FlagSet) { - f.StringVar(&profileHost, "host", "127.0.0.1:7777", "specify the host of the goc server") - f.StringSliceVar(&profileIds, "id", nil, "specify the ids of the services") - f.StringVar(&profileExtra, "extra", "", "specify the regex expression of extra, only profile with extra information will be downloaded") - } + add1Flags := func(f *pflag.FlagSet) { + f.StringVar(&profileHost, "host", "127.0.0.1:7777", "specify the host of the goc server") + f.StringSliceVar(&profileIds, "id", nil, "specify the ids of the services") + f.StringVar(&profileExtra, "extra", "", "specify the regex expression of extra, only profile with extra information will be downloaded") + } - add2Flags := func(f *pflag.FlagSet) { - f.StringVarP(&profileOutput, "output", "o", "", "download cover profile") - f.StringSliceVar(&profileSkipPattern, "skip", nil, "skip specific packages in the profile") - f.StringSliceVarP(&profileNeedPattern, "need", "n", nil, "find specific packages in the profile") - } + add2Flags := func(f *pflag.FlagSet) { + f.StringVarP(&profileOutput, "output", "o", "", "download cover profile") + f.StringSliceVar(&profileSkipPattern, "skip", nil, "skip specific packages in the profile") + f.StringSliceVarP(&profileNeedPattern, "need", "n", nil, "find specific packages in the profile") + } - add1Flags(getProfileCmd.Flags()) - add2Flags(getProfileCmd.Flags()) + add1Flags(getProfileCmd.Flags()) + add2Flags(getProfileCmd.Flags()) - add1Flags(clearProfileCmd.Flags()) + add1Flags(clearProfileCmd.Flags()) - profileCmd.AddCommand(getProfileCmd) - profileCmd.AddCommand(clearProfileCmd) - rootCmd.AddCommand(profileCmd) + profileCmd.AddCommand(getProfileCmd) + profileCmd.AddCommand(clearProfileCmd) + rootCmd.AddCommand(profileCmd) } var getProfileCmd = &cobra.Command{ - Use: "get", - Run: getProfile, + Use: "get", + Run: getProfile, } func getProfile(cmd *cobra.Command, args []string) { - client.GetProfile(profileHost, profileIds, profileSkipPattern, profileExtra, profileOutput, profileNeedPattern) + client.GetProfile(profileHost, profileIds, profileSkipPattern, profileExtra, profileOutput, profileNeedPattern) } var clearProfileCmd = &cobra.Command{ - Use: "clear", - Run: clearProfile, + Use: "clear", + Run: clearProfile, } func clearProfile(cmd *cobra.Command, args []string) { - client.ClearProfile(profileHost, profileIds, profileExtra) + client.ClearProfile(profileHost, profileIds, profileExtra) } diff --git a/cmd/root.go b/cmd/root.go index 4e376e9..9ffa482 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -14,38 +14,38 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/spf13/cobra" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/spf13/cobra" ) var rootCmd = &cobra.Command{ - Use: "goc", - Short: "goc is a comprehensive coverage testing tool for go language", - Long: `goc is a comprehensive coverage testing tool for go language. + Use: "goc", + Short: "goc is a comprehensive coverage testing tool for go language", + Long: `goc is a comprehensive coverage testing tool for go language. Find more information at: https://github.com/qiniu/goc `, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - //log.DisplayGoc() - // init logger - log.NewLogger(globalDebug) + PersistentPreRun: func(cmd *cobra.Command, args []string) { + //log.DisplayGoc() + // init logger + log.NewLogger(globalDebug) - }, + }, - PersistentPostRun: func(cmd *cobra.Command, args []string) { - log.Sync() - }, + PersistentPostRun: func(cmd *cobra.Command, args []string) { + log.Sync() + }, } var globalDebug bool func init() { - rootCmd.PersistentFlags().BoolVar(&globalDebug, "gocdebug", false, "run goc in debug mode") + rootCmd.PersistentFlags().BoolVar(&globalDebug, "gocdebug", false, "run goc in debug mode") } // Execute the goc tool func Execute() { - if err := rootCmd.Execute(); err != nil { - } + if err := rootCmd.Execute(); err != nil { + } } diff --git a/cmd/run.go b/cmd/run.go index 2e02e84..ca9a725 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -14,34 +14,34 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/build" - "github.com/spf13/cobra" + "github.com/ar0c/goc/v2/pkg/build" + "github.com/spf13/cobra" ) var runCmd = &cobra.Command{ - Use: "run", - Run: runAction, + Use: "run", + Run: runAction, - DisableFlagParsing: true, // run 命令需要用原生 go 的方式处理 flags + DisableFlagParsing: true, // run 命令需要用原生 go 的方式处理 flags } func init() { - runCmd.Flags().StringVarP(&gocmode, "gocmode", "", "count", "coverage mode: set, count, atomic, watch") - runCmd.Flags().StringVarP(&gochost, "gochost", "", "127.0.0.1:7777", "specify the host of the goc sever") - rootCmd.AddCommand(runCmd) + runCmd.Flags().StringVarP(&gocmode, "gocmode", "", "count", "coverage mode: set, count, atomic, watch") + runCmd.Flags().StringVarP(&gochost, "gochost", "", "127.0.0.1:7777", "specify the host of the goc sever") + rootCmd.AddCommand(runCmd) } func runAction(cmd *cobra.Command, args []string) { - sets := build.CustomParseCmdAndArgs(cmd, args) + sets := build.CustomParseCmdAndArgs(cmd, args) - b := build.NewRun( - build.WithHost(gochost), - build.WithMode(gocmode), - build.WithFlagSets(sets), - build.WithArgs(args), - build.WithBuild(), - build.WithDebug(globalDebug), - ) - b.Run() + b := build.NewRun( + build.WithHost(gochost), + build.WithMode(gocmode), + build.WithFlagSets(sets), + build.WithArgs(args), + build.WithBuild(), + build.WithDebug(globalDebug), + ) + b.Run() } diff --git a/cmd/server.go b/cmd/server.go index dc1161c..f277cf9 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -14,36 +14,36 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/RickLeee/goc/v2/pkg/server" - "github.com/RickLeee/goc/v2/pkg/server/store" - "github.com/spf13/cobra" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/server" + "github.com/ar0c/goc/v2/pkg/server/store" + "github.com/spf13/cobra" ) var serverCmd = &cobra.Command{ - Use: "server", - Short: "Start a service registry center", - Example: "", + Use: "server", + Short: "Start a service registry center", + Example: "", - Run: serve, + Run: serve, } var ( - serverHost string - serverStore string + serverHost string + serverStore string ) func init() { - serverCmd.Flags().StringVarP(&serverHost, "host", "", "127.0.0.1:7777", "specify the host of the goc server") - serverCmd.Flags().StringVarP(&serverStore, "store", "", ".goc.kvstore", "specify the host of the goc server") + serverCmd.Flags().StringVarP(&serverHost, "host", "", "127.0.0.1:7777", "specify the host of the goc server") + serverCmd.Flags().StringVarP(&serverStore, "store", "", ".goc.kvstore", "specify the host of the goc server") - rootCmd.AddCommand(serverCmd) + rootCmd.AddCommand(serverCmd) } func serve(cmd *cobra.Command, args []string) { - s, err := store.NewFileStore(serverStore) - if err != nil { - log.Fatalf("cannot create store for goc server: %v", err) - } - server.RunGocServerUntilExit(serverHost, s) + s, err := store.NewFileStore(serverStore) + if err != nil { + log.Fatalf("cannot create store for goc server: %v", err) + } + server.RunGocServerUntilExit(serverHost, s) } diff --git a/cmd/server_test.go b/cmd/server_test.go new file mode 100644 index 0000000..617ad34 --- /dev/null +++ b/cmd/server_test.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "github.com/ar0c/goc/v2/pkg/log" + "github.com/spf13/cobra" + "testing" +) + +func Test_serve(t *testing.T) { + type args struct { + cmd *cobra.Command + args []string + } + tests := []struct { + name string + args args + }{ + { + name: "", + args: args{}, + }, + // TODO: Add test cases. + } + for _, tt := range tests { + log.NewLogger(true) + t.Run(tt.name, func(t *testing.T) { + serve(tt.args.cmd, tt.args.args) + }) + } +} diff --git a/cmd/service.go b/cmd/service.go index bad40ac..70fa47e 100644 --- a/cmd/service.go +++ b/cmd/service.go @@ -14,61 +14,61 @@ package cmd import ( - "github.com/RickLeee/goc/v2/pkg/client" - "github.com/spf13/cobra" - "github.com/spf13/pflag" + "github.com/ar0c/goc/v2/pkg/client" + "github.com/spf13/cobra" + "github.com/spf13/pflag" ) var listCmd = &cobra.Command{ - Use: "service", - Short: "Deal with the registered services", - Long: `It can be used to list, remove the registered services. + Use: "service", + Short: "Deal with the registered services", + Long: `It can be used to list, remove the registered services. For disconnected services, remove will delete these serivces forever, for connected services remove will force these services register again.`, } var ( - listHost string - listWide bool - listIds []string - listJson bool + listHost string + listWide bool + listIds []string + listJson bool ) func init() { - add1Flags := func(f *pflag.FlagSet) { - f.StringVar(&listHost, "host", "127.0.0.1:7777", "specify the host of the goc server") - f.BoolVar(&listWide, "wide", false, "list all services with more information (such as pid)") - f.BoolVar(&listJson, "json", false, "list all services info as json format") - f.StringSliceVar(&listIds, "id", nil, "specify the ids of the services") - } + add1Flags := func(f *pflag.FlagSet) { + f.StringVar(&listHost, "host", "127.0.0.1:7777", "specify the host of the goc server") + f.BoolVar(&listWide, "wide", false, "list all services with more information (such as pid)") + f.BoolVar(&listJson, "json", false, "list all services info as json format") + f.StringSliceVar(&listIds, "id", nil, "specify the ids of the services") + } - add1Flags(getServiceCmd.Flags()) - add1Flags(deleteServiceCmd.Flags()) + add1Flags(getServiceCmd.Flags()) + add1Flags(deleteServiceCmd.Flags()) - listCmd.AddCommand(getServiceCmd) - listCmd.AddCommand(deleteServiceCmd) - rootCmd.AddCommand(listCmd) + listCmd.AddCommand(getServiceCmd) + listCmd.AddCommand(deleteServiceCmd) + rootCmd.AddCommand(listCmd) } func list(cmd *cobra.Command, args []string) { - client.ListAgents(listHost, listIds, listWide, listJson) + client.ListAgents(listHost, listIds, listWide, listJson) } var getServiceCmd = &cobra.Command{ - Use: "get", - Run: getAgents, + Use: "get", + Run: getAgents, } func getAgents(cmd *cobra.Command, args []string) { - client.ListAgents(listHost, listIds, listWide, listJson) + client.ListAgents(listHost, listIds, listWide, listJson) } var deleteServiceCmd = &cobra.Command{ - Use: "delete", - Run: deleteAgents, + Use: "delete", + Run: deleteAgents, } func deleteAgents(cmd *cobra.Command, args []string) { - client.DeleteAgents(listHost, listIds) + client.DeleteAgents(listHost, listIds) } diff --git a/cmd/watch.go b/cmd/watch.go index d968d5e..17a47a7 100644 --- a/cmd/watch.go +++ b/cmd/watch.go @@ -14,28 +14,28 @@ package cmd import ( - cli "github.com/RickLeee/goc/v2/pkg/watch" - "github.com/spf13/cobra" + cli "github.com/ar0c/goc/v2/pkg/watch" + "github.com/spf13/cobra" ) var watchCmd = &cobra.Command{ - Use: "watch", - Short: "watch for profile's real time update", - Long: "watch for profile's real time update", - Example: "", + Use: "watch", + Short: "watch for profile's real time update", + Long: "watch for profile's real time update", + Example: "", - Run: watch, + Run: watch, } var ( - watchHost string + watchHost string ) func init() { - watchCmd.Flags().StringVarP(&watchHost, "host", "", "127.0.0.1:7777", "specify the host of the goc server") - rootCmd.AddCommand(watchCmd) + watchCmd.Flags().StringVarP(&watchHost, "host", "", "127.0.0.1:7777", "specify the host of the goc server") + rootCmd.AddCommand(watchCmd) } func watch(cmd *cobra.Command, args []string) { - cli.Watch(watchHost) + cli.Watch(watchHost) } diff --git a/go.mod b/go.mod index 329357a..e71b6b5 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/RickLeee/goc/v2 +module github.com/ar0c/goc/v2 go 1.22.0 diff --git a/hack/release.sh b/hack/release.sh index 3b56ad3..8c58299 100755 --- a/hack/release.sh +++ b/hack/release.sh @@ -10,7 +10,7 @@ RELEASE_VERSION=$(echo $EVENT_DATA | jq -r .release.tag_name) PROJECT_NAME=$(basename $GITHUB_REPOSITORY) NAME="${NAME:-${PROJECT_NAME}-${RELEASE_VERSION}}-${GOOS}-${GOARCH}" -CGO_ENABLED=0 go build -o goc -ldflags "-X 'github.com/RickLeee/goc/v2/cmd.Version=${RELEASE_VERSION}'" . +CGO_ENABLED=0 go build -o goc -ldflags "-X 'github.com/ar0c/goc/v2/cmd.Version=${RELEASE_VERSION}'" . ARCHIVE=tmp.tar.gz FILE_LIST=goc diff --git a/main.go b/main.go index c18784f..bdac580 100644 --- a/main.go +++ b/main.go @@ -14,9 +14,9 @@ package main import ( - "github.com/RickLeee/goc/v2/cmd" + "github.com/ar0c/goc/v2/cmd" ) func main() { - cmd.Execute() + cmd.Execute() } diff --git a/pkg/build/build.go b/pkg/build/build.go index 4e8835a..3e184c8 100644 --- a/pkg/build/build.go +++ b/pkg/build/build.go @@ -18,7 +18,7 @@ import ( "os/exec" "strings" - "github.com/RickLeee/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/log" "github.com/spf13/pflag" ) diff --git a/pkg/build/build_flags.go b/pkg/build/build_flags.go index c63b8d5..43a8c94 100644 --- a/pkg/build/build_flags.go +++ b/pkg/build/build_flags.go @@ -14,15 +14,15 @@ package build import ( - "flag" - "fmt" - "os" - "path/filepath" - "strings" + "flag" + "fmt" + "os" + "path/filepath" + "strings" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/spf13/cobra" - "github.com/spf13/pflag" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/spf13/cobra" + "github.com/spf13/pflag" ) var buildUsage string = `Usage: @@ -53,23 +53,23 @@ However, other flags' order are same with the go official command. ` const ( - GO_BUILD = iota - GO_INSTALL + GO_BUILD = iota + GO_INSTALL ) // CustomParseCmdAndArgs 因为关闭了 cobra 的解析功能,需要手动构造并解析 goc flags func CustomParseCmdAndArgs(cmd *cobra.Command, args []string) *pflag.FlagSet { - // 首先解析 cobra 定义的 flag - allFlagSets := cmd.Flags() - // 因为 args 里面含有 go 的 flag,所以需要忽略解析 go flag 的错误 - allFlagSets.Init("GOC", pflag.ContinueOnError) - // 忽略 go flag 在 goc 中的解析错误 - allFlagSets.ParseErrorsWhitelist = pflag.ParseErrorsWhitelist{ - UnknownFlags: true, - } - allFlagSets.Parse(args) + // 首先解析 cobra 定义的 flag + allFlagSets := cmd.Flags() + // 因为 args 里面含有 go 的 flag,所以需要忽略解析 go flag 的错误 + allFlagSets.Init("GOC", pflag.ContinueOnError) + // 忽略 go flag 在 goc 中的解析错误 + allFlagSets.ParseErrorsWhitelist = pflag.ParseErrorsWhitelist{ + UnknownFlags: true, + } + allFlagSets.Parse(args) - return allFlagSets + return allFlagSets } // buildCmdArgsParse parse both go flags and goc flags, it rewrite go flags if @@ -77,246 +77,246 @@ func CustomParseCmdAndArgs(cmd *cobra.Command, args []string) *pflag.FlagSet { // // 吞下 [packages] 之前所有的 flags. func (b *Build) buildCmdArgsParse() { - args := b.Args - cmdType := b.BuildType - allFlagSets := b.FlagSets + args := b.Args + cmdType := b.BuildType + allFlagSets := b.FlagSets - // 重写 help - helpFlag := allFlagSets.Lookup("help") + // 重写 help + helpFlag := allFlagSets.Lookup("help") - if helpFlag.Changed { - if cmdType == GO_BUILD { - printGoHelp(buildUsage) - } else if cmdType == GO_INSTALL { - printGoHelp(installUsage) - } + if helpFlag.Changed { + if cmdType == GO_BUILD { + printGoHelp(buildUsage) + } else if cmdType == GO_INSTALL { + printGoHelp(installUsage) + } - os.Exit(0) - } - // 删除 help flag - args = findAndDelHelpFlag(args) + os.Exit(0) + } + // 删除 help flag + args = findAndDelHelpFlag(args) - // 必须手动调用 - // 由于关闭了 cobra 的 flag parse,root PersistentPreRun 调用时,log.NewLogger 并没有拿到 debug 值 - log.NewLogger(b.Debug) + // 必须手动调用 + // 由于关闭了 cobra 的 flag parse,root PersistentPreRun 调用时,log.NewLogger 并没有拿到 debug 值 + log.NewLogger(b.Debug) - // 删除 cobra 定义的 flag - allFlagSets.Visit(func(f *pflag.Flag) { - args = findAndDelGocFlag(args, f.Name, f.Value.String()) - }) + // 删除 cobra 定义的 flag + allFlagSets.Visit(func(f *pflag.Flag) { + args = findAndDelGocFlag(args, f.Name, f.Value.String()) + }) - // 然后解析 go 的 flag - goFlagSets := flag.NewFlagSet("GO", flag.ContinueOnError) - addBuildFlags(goFlagSets) - addOutputFlags(goFlagSets) - err := goFlagSets.Parse(args) - if err != nil { - log.Fatalf("%v", err) - } + // 然后解析 go 的 flag + goFlagSets := flag.NewFlagSet("GO", flag.ContinueOnError) + addBuildFlags(goFlagSets) + addOutputFlags(goFlagSets) + err := goFlagSets.Parse(args) + if err != nil { + log.Fatalf("%v", err) + } - // 找出设置的 go flag - curWd, err := os.Getwd() - if err != nil { - log.Fatalf("fail to get current working directory: %v", err) - } - flags := make([]string, 0) - goFlagSets.Visit(func(f *flag.Flag) { - // 将用户指定 -o 改成绝对目录 - if f.Name == "o" { - outputDir := f.Value.String() - outputDir, err := filepath.Abs(outputDir) - if err != nil { - log.Fatalf("output flag is not valid: %v", err) - } - flags = append(flags, "-o", outputDir) - } else { - if _, ok := booleanFlags[f.Name]; !ok { - flags = append(flags, "-"+f.Name, f.Value.String()) - } else { - flags = append(flags, "-"+f.Name) - } - if f.Name == "mod" { - if f.Value.String() == "vendor" { - b.IsVendorMod = true - } else { - b.IsVendorMod = false - } - } - } - }) + // 找出设置的 go flag + curWd, err := os.Getwd() + if err != nil { + log.Fatalf("fail to get current working directory: %v", err) + } + flags := make([]string, 0) + goFlagSets.Visit(func(f *flag.Flag) { + // 将用户指定 -o 改成绝对目录 + if f.Name == "o" { + outputDir := f.Value.String() + outputDir, err := filepath.Abs(outputDir) + if err != nil { + log.Fatalf("output flag is not valid: %v", err) + } + flags = append(flags, "-o", outputDir) + } else { + if _, ok := booleanFlags[f.Name]; !ok { + flags = append(flags, "-"+f.Name, f.Value.String()) + } else { + flags = append(flags, "-"+f.Name) + } + if f.Name == "mod" { + if f.Value.String() == "vendor" { + b.IsVendorMod = true + } else { + b.IsVendorMod = false + } + } + } + }) - b.Goflags = flags - b.CurWd = curWd - b.GoArgs = goFlagSets.Args() - return + b.Goflags = flags + b.CurWd = curWd + b.GoArgs = goFlagSets.Args() + return } func (b *Build) runCmdArgsParse() { - args := b.Args - allFlagSets := b.FlagSets + args := b.Args + allFlagSets := b.FlagSets - // 重写 help - helpFlag := allFlagSets.Lookup("help") + // 重写 help + helpFlag := allFlagSets.Lookup("help") - if helpFlag.Changed { - printGoHelp(runUsage) - os.Exit(0) - } + if helpFlag.Changed { + printGoHelp(runUsage) + os.Exit(0) + } - // 删除 help flag - args = findAndDelHelpFlag(args) + // 删除 help flag + args = findAndDelHelpFlag(args) - // 必须手动调用 - // 由于关闭了 cobra 的 flag parse,root PersistentPreRun 调用时,log.NewLogger 并没有拿到 debug 值 - log.NewLogger(b.Debug) + // 必须手动调用 + // 由于关闭了 cobra 的 flag parse,root PersistentPreRun 调用时,log.NewLogger 并没有拿到 debug 值 + log.NewLogger(b.Debug) - curWd, err := os.Getwd() - if err != nil { - log.Fatalf("fail to get current working directory: %v", err) - } - b.CurWd = curWd + curWd, err := os.Getwd() + if err != nil { + log.Fatalf("fail to get current working directory: %v", err) + } + b.CurWd = curWd - // 获取除 goc flags 之外的 args - // 删除 cobra 定义的 flag - allFlagSets.Visit(func(f *pflag.Flag) { - args = findAndDelGocFlag(args, f.Name, f.Value.String()) - }) + // 获取除 goc flags 之外的 args + // 删除 cobra 定义的 flag + allFlagSets.Visit(func(f *pflag.Flag) { + args = findAndDelGocFlag(args, f.Name, f.Value.String()) + }) - b.GoArgs = args + b.GoArgs = args } func findAndDelGocFlag(a []string, x string, v string) []string { - new := make([]string, 0, len(a)) - x = "--" + x - x_v := x + "=" + v - for i := 0; i < len(a); i++ { - if a[i] == "--gocdebug" { - // debug 是 bool,就一个元素 - continue - } else if a[i] == x { - // 有 goc flag 长这样 --mode watch - i++ - continue - } else if a[i] == x_v { - // 有 goc flag 长这样 --mode=watch - continue - } else { - // 剩下的是 go flag - new = append(new, a[i]) - } - } + new := make([]string, 0, len(a)) + x = "--" + x + x_v := x + "=" + v + for i := 0; i < len(a); i++ { + if a[i] == "--gocdebug" { + // debug 是 bool,就一个元素 + continue + } else if a[i] == x { + // 有 goc flag 长这样 --mode watch + i++ + continue + } else if a[i] == x_v { + // 有 goc flag 长这样 --mode=watch + continue + } else { + // 剩下的是 go flag + new = append(new, a[i]) + } + } - return new + return new } func findAndDelHelpFlag(a []string) []string { - new := make([]string, 0, len(a)) - for _, v := range a { - if v == "--help" || v == "-h" { - continue - } else { - new = append(new, v) - } - } + new := make([]string, 0, len(a)) + for _, v := range a { + if v == "--help" || v == "-h" { + continue + } else { + new = append(new, v) + } + } - return new + return new } type goConfig struct { - BuildA bool - BuildBuildmode string // -buildmode flag - BuildMod string // -mod flag - BuildModReason string // reason -mod flag is set, if set by default - BuildI bool // -i flag - BuildLinkshared bool // -linkshared flag - BuildMSan bool // -msan flag - BuildN bool // -n flag - BuildO string // -o flag - BuildP int // -p flag - BuildPkgdir string // -pkgdir flag - BuildRace bool // -race flag - BuildToolexec string // -toolexec flag - BuildToolchainName string - BuildToolchainCompiler func() string - BuildToolchainLinker func() string - BuildTrimpath bool // -trimpath flag - BuildV bool // -v flag - BuildWork bool // -work flag - BuildX bool // -x flag - // from buildcontext - Installsuffix string // -installSuffix - BuildTags string // -tags - // from load - BuildAsmflags string - BuildCompiler string - BuildGcflags string - BuildGccgoflags string - BuildLdflags string + BuildA bool + BuildBuildmode string // -buildmode flag + BuildMod string // -mod flag + BuildModReason string // reason -mod flag is set, if set by default + BuildI bool // -i flag + BuildLinkshared bool // -linkshared flag + BuildMSan bool // -msan flag + BuildN bool // -n flag + BuildO string // -o flag + BuildP int // -p flag + BuildPkgdir string // -pkgdir flag + BuildRace bool // -race flag + BuildToolexec string // -toolexec flag + BuildToolchainName string + BuildToolchainCompiler func() string + BuildToolchainLinker func() string + BuildTrimpath bool // -trimpath flag + BuildV bool // -v flag + BuildWork bool // -work flag + BuildX bool // -x flag + // from buildcontext + Installsuffix string // -installSuffix + BuildTags string // -tags + // from load + BuildAsmflags string + BuildCompiler string + BuildGcflags string + BuildGccgoflags string + BuildLdflags string - // mod related - ModCacheRW bool - ModFile string + // mod related + ModCacheRW bool + ModFile string } var goflags goConfig var booleanFlags map[string]struct{} = make(map[string]struct{}) func addBuildFlags(cmdSet *flag.FlagSet) { - cmdSet.BoolVar(&goflags.BuildA, "a", false, "") - booleanFlags["a"] = struct{}{} - cmdSet.BoolVar(&goflags.BuildN, "n", false, "") - booleanFlags["n"] = struct{}{} - cmdSet.IntVar(&goflags.BuildP, "p", 4, "") - cmdSet.BoolVar(&goflags.BuildV, "v", false, "") - booleanFlags["v"] = struct{}{} - cmdSet.BoolVar(&goflags.BuildX, "x", false, "") - booleanFlags["x"] = struct{}{} - cmdSet.StringVar(&goflags.BuildBuildmode, "buildmode", "default", "") - cmdSet.StringVar(&goflags.BuildMod, "mod", "", "") - cmdSet.StringVar(&goflags.Installsuffix, "installsuffix", "", "") + cmdSet.BoolVar(&goflags.BuildA, "a", false, "") + booleanFlags["a"] = struct{}{} + cmdSet.BoolVar(&goflags.BuildN, "n", false, "") + booleanFlags["n"] = struct{}{} + cmdSet.IntVar(&goflags.BuildP, "p", 4, "") + cmdSet.BoolVar(&goflags.BuildV, "v", false, "") + booleanFlags["v"] = struct{}{} + cmdSet.BoolVar(&goflags.BuildX, "x", false, "") + booleanFlags["x"] = struct{}{} + cmdSet.StringVar(&goflags.BuildBuildmode, "buildmode", "default", "") + cmdSet.StringVar(&goflags.BuildMod, "mod", "", "") + cmdSet.StringVar(&goflags.Installsuffix, "installsuffix", "", "") - // 类型和 go 原生的不一样,这里纯粹是为了 parse 并传递给 go - cmdSet.StringVar(&goflags.BuildAsmflags, "asmflags", "", "") - cmdSet.StringVar(&goflags.BuildCompiler, "compiler", "", "") - cmdSet.StringVar(&goflags.BuildGcflags, "gcflags", "", "") - cmdSet.StringVar(&goflags.BuildGccgoflags, "gccgoflags", "", "") - // mod related - cmdSet.BoolVar(&goflags.ModCacheRW, "modcacherw", false, "") - booleanFlags["modcacherw"] = struct{}{} - cmdSet.StringVar(&goflags.ModFile, "modfile", "", "") - cmdSet.StringVar(&goflags.BuildLdflags, "ldflags", "", "") - cmdSet.BoolVar(&goflags.BuildLinkshared, "linkshared", false, "") - booleanFlags["linkshared"] = struct{}{} - cmdSet.StringVar(&goflags.BuildPkgdir, "pkgdir", "", "") - cmdSet.BoolVar(&goflags.BuildRace, "race", false, "") - booleanFlags["race"] = struct{}{} - cmdSet.BoolVar(&goflags.BuildMSan, "msan", false, "") - booleanFlags["msan"] = struct{}{} - cmdSet.StringVar(&goflags.BuildTags, "tags", "", "") - cmdSet.StringVar(&goflags.BuildToolexec, "toolexec", "", "") - cmdSet.BoolVar(&goflags.BuildTrimpath, "trimpath", false, "") - booleanFlags["trimpath"] = struct{}{} - cmdSet.BoolVar(&goflags.BuildWork, "work", false, "") - booleanFlags["work"] = struct{}{} + // 类型和 go 原生的不一样,这里纯粹是为了 parse 并传递给 go + cmdSet.StringVar(&goflags.BuildAsmflags, "asmflags", "", "") + cmdSet.StringVar(&goflags.BuildCompiler, "compiler", "", "") + cmdSet.StringVar(&goflags.BuildGcflags, "gcflags", "", "") + cmdSet.StringVar(&goflags.BuildGccgoflags, "gccgoflags", "", "") + // mod related + cmdSet.BoolVar(&goflags.ModCacheRW, "modcacherw", false, "") + booleanFlags["modcacherw"] = struct{}{} + cmdSet.StringVar(&goflags.ModFile, "modfile", "", "") + cmdSet.StringVar(&goflags.BuildLdflags, "ldflags", "", "") + cmdSet.BoolVar(&goflags.BuildLinkshared, "linkshared", false, "") + booleanFlags["linkshared"] = struct{}{} + cmdSet.StringVar(&goflags.BuildPkgdir, "pkgdir", "", "") + cmdSet.BoolVar(&goflags.BuildRace, "race", false, "") + booleanFlags["race"] = struct{}{} + cmdSet.BoolVar(&goflags.BuildMSan, "msan", false, "") + booleanFlags["msan"] = struct{}{} + cmdSet.StringVar(&goflags.BuildTags, "tags", "", "") + cmdSet.StringVar(&goflags.BuildToolexec, "toolexec", "", "") + cmdSet.BoolVar(&goflags.BuildTrimpath, "trimpath", false, "") + booleanFlags["trimpath"] = struct{}{} + cmdSet.BoolVar(&goflags.BuildWork, "work", false, "") + booleanFlags["work"] = struct{}{} } func addOutputFlags(cmdSet *flag.FlagSet) { - cmdSet.StringVar(&goflags.BuildO, "o", "", "") + cmdSet.StringVar(&goflags.BuildO, "o", "", "") } func printGoHelp(usage string) { - fmt.Println(usage) + fmt.Println(usage) } func printGocHelp(cmd *cobra.Command) { - flags := cmd.LocalFlags() - globalFlags := cmd.Parent().PersistentFlags() + flags := cmd.LocalFlags() + globalFlags := cmd.Parent().PersistentFlags() - fmt.Println("Flags:") - fmt.Println(flags.FlagUsages()) + fmt.Println("Flags:") + fmt.Println(flags.FlagUsages()) - fmt.Println("Global Flags:") - fmt.Println(globalFlags.FlagUsages()) + fmt.Println("Global Flags:") + fmt.Println(globalFlags.FlagUsages()) } // GetPackagesDir parse [pacakges] part of args, it will fatal if error encountered @@ -328,43 +328,43 @@ func printGocHelp(cmd *cobra.Command) { // 如果 [packages] 非法(即不符合 go 原生的定义),则返回对应错误 // 这里只考虑 go mod 的方式 func (b *Build) getPackagesDir() { - patterns := b.GoArgs - packages := make([]string, 0) - for _, p := range patterns { - // patterns 只支持两种格式 - // 1. 要么是直接指向某些 .go 文件的相对/绝对路径 - if strings.HasSuffix(p, ".go") { - if fi, err := os.Stat(p); err == nil && !fi.IsDir() { - // check if valid - if err := goFilesPackage(patterns); err != nil { - log.Fatalf("%v", err) - } + patterns := b.GoArgs + packages := make([]string, 0) + for _, p := range patterns { + // patterns 只支持两种格式 + // 1. 要么是直接指向某些 .go 文件的相对/绝对路径 + if strings.HasSuffix(p, ".go") { + if fi, err := os.Stat(p); err == nil && !fi.IsDir() { + // check if valid + if err := goFilesPackage(patterns); err != nil { + log.Fatalf("%v", err) + } - // 获取相对于 current working directory 对路径 - for _, p := range patterns { - if filepath.IsAbs(p) { - relPath, err := filepath.Rel(b.CurWd, p) - if err != nil { - log.Fatalf("fail to get [packages] relative path from current working directory: %v", err) - } - packages = append(packages, relPath) - } else { - packages = append(packages, p) - } - } - // fix: go build ./xx/main.go 需要转换为 - // go build ./xx/main.go ./xx/goc-cover-agent-apis-auto-generated-11111-22222-bridge.go - dir := filepath.Dir(packages[0]) - packages = append(packages, filepath.Join(dir, "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go")) - b.Packages = packages + // 获取相对于 current working directory 对路径 + for _, p := range patterns { + if filepath.IsAbs(p) { + relPath, err := filepath.Rel(b.CurWd, p) + if err != nil { + log.Fatalf("fail to get [packages] relative path from current working directory: %v", err) + } + packages = append(packages, relPath) + } else { + packages = append(packages, p) + } + } + // fix: go build ./xx/main.go 需要转换为 + // go build ./xx/main.go ./xx/goc-cover-agent-apis-auto-generated-11111-22222-bridge.go + dir := filepath.Dir(packages[0]) + packages = append(packages, filepath.Join(dir, "goc-cover-agent-apis-auto-generated-11111-22222-bridge.go")) + b.Packages = packages - return - } - } - } + return + } + } + } - // 2. 要么是 import path - b.Packages = patterns + // 2. 要么是 import path + b.Packages = patterns } // goFilesPackage 对一组 go 文件解析,判断是否合法 @@ -373,40 +373,40 @@ func (b *Build) getPackagesDir() { // 2. *.go 文件都在同一个目录? // 3. *.go 文件存在? func goFilesPackage(gofiles []string) error { - // 1. 必须都是 *.go 结尾 - for _, f := range gofiles { - if !strings.HasSuffix(f, ".go") { - return fmt.Errorf("named files must be .go files: %s", f) - } - } + // 1. 必须都是 *.go 结尾 + for _, f := range gofiles { + if !strings.HasSuffix(f, ".go") { + return fmt.Errorf("named files must be .go files: %s", f) + } + } - var dir string - for _, file := range gofiles { - // 3. 文件都存在? - fi, err := os.Stat(file) - if err != nil { - return err - } + var dir string + for _, file := range gofiles { + // 3. 文件都存在? + fi, err := os.Stat(file) + if err != nil { + return err + } - // 2.1 有可能以 *.go 结尾的目录 - if fi.IsDir() { - return fmt.Errorf("%s is a directory, should be a Go file", file) - } + // 2.1 有可能以 *.go 结尾的目录 + if fi.IsDir() { + return fmt.Errorf("%s is a directory, should be a Go file", file) + } - // 2.2 所有 *.go 必须在同一个目录内 - dir1, _ := filepath.Split(file) - if dir1 == "" { - dir1 = "./" - } + // 2.2 所有 *.go 必须在同一个目录内 + dir1, _ := filepath.Split(file) + if dir1 == "" { + dir1 = "./" + } - if dir == "" { - dir = dir1 - } else if dir != dir1 { - return fmt.Errorf("named files must all be in one directory: have %s and %s", dir, dir1) - } - } + if dir == "" { + dir = dir1 + } else if dir != dir1 { + return fmt.Errorf("named files must all be in one directory: have %s and %s", dir, dir1) + } + } - return nil + return nil } // getDirFromImportPaths return the import path's real abs directory @@ -414,90 +414,90 @@ func goFilesPackage(gofiles []string) error { // 该函数接收到的只有 dir 或 import path,file 在上一步已被排除 // 只考虑 go modules 的情况 func getDirFromImportPaths(patterns []string) (string, error) { - // no import path, pattern = current wd - if len(patterns) == 0 { - wd, err := os.Getwd() - if err != nil { - return "", fmt.Errorf("fail to parse import path: %w", err) - } - return wd, nil - } + // no import path, pattern = current wd + if len(patterns) == 0 { + wd, err := os.Getwd() + if err != nil { + return "", fmt.Errorf("fail to parse import path: %w", err) + } + return wd, nil + } - // 为了简化插桩的逻辑,goc 对 import path 要求必须都在同一个目录 - // 所以干脆只允许一个 pattern 得了 -_- - // 对于 goc build/run 来说本身就是只能在一个目录内 - // 对于 goc install 来讲,这个行为就和 go install 不同,不过多 import path 较少见 >_<,先忽略 - if len(patterns) > 1 { - return "", fmt.Errorf("goc only support one import path now") - } + // 为了简化插桩的逻辑,goc 对 import path 要求必须都在同一个目录 + // 所以干脆只允许一个 pattern 得了 -_- + // 对于 goc build/run 来说本身就是只能在一个目录内 + // 对于 goc install 来讲,这个行为就和 go install 不同,不过多 import path 较少见 >_<,先忽略 + if len(patterns) > 1 { + return "", fmt.Errorf("goc only support one import path now") + } - pattern := patterns[0] - switch { - // case isLocalImport(pattern) || filepath.IsAbs(pattern): - // dir1, err := filepath.Abs(pattern) - // if err != nil { - // return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) - // } - // if _, err := os.Stat(dir1); err != nil { - // return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) - // } - // return dir1, nil + pattern := patterns[0] + switch { + // case isLocalImport(pattern) || filepath.IsAbs(pattern): + // dir1, err := filepath.Abs(pattern) + // if err != nil { + // return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) + // } + // if _, err := os.Stat(dir1); err != nil { + // return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) + // } + // return dir1, nil - case strings.Contains(pattern, "..."): - i := strings.Index(pattern, "...") - dir, _ := filepath.Split(pattern[:i]) - dir, _ = filepath.Abs(dir) - if _, err := os.Stat(dir); err != nil { - return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) - } - return dir, nil + case strings.Contains(pattern, "..."): + i := strings.Index(pattern, "...") + dir, _ := filepath.Split(pattern[:i]) + dir, _ = filepath.Abs(dir) + if _, err := os.Stat(dir); err != nil { + return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) + } + return dir, nil - case strings.IndexByte(pattern, '@') > 0: - return "", fmt.Errorf("import path with @ version query is not supported in goc") + case strings.IndexByte(pattern, '@') > 0: + return "", fmt.Errorf("import path with @ version query is not supported in goc") - case isMetaPackage(pattern): - return "", fmt.Errorf("`std`, `cmd`, `all` import path is not supported by goc") + case isMetaPackage(pattern): + return "", fmt.Errorf("`std`, `cmd`, `all` import path is not supported by goc") - default: // 到这一步认为 pattern 是相对路径或者绝对路径 - dir1, err := filepath.Abs(pattern) - if err != nil { - return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) - } - if _, err := os.Stat(dir1); err != nil { - return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) - } + default: // 到这一步认为 pattern 是相对路径或者绝对路径 + dir1, err := filepath.Abs(pattern) + if err != nil { + return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) + } + if _, err := os.Stat(dir1); err != nil { + return "", fmt.Errorf("error (%w) get directory from the import path: %v", err, pattern) + } - return dir1, nil - } + return dir1, nil + } } // isLocalImport reports whether the import path is // a local import path, like ".", "..", "./foo", or "../foo" func isLocalImport(path string) bool { - return path == "." || path == ".." || - strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") + return path == "." || path == ".." || + strings.HasPrefix(path, "./") || strings.HasPrefix(path, "../") } // isMetaPackage checks if the name is a reserved package name func isMetaPackage(name string) bool { - return name == "std" || name == "cmd" || name == "all" + return name == "std" || name == "cmd" || name == "all" } // find direct path of current project which contains go.mod func findModuleRoot(dir string) string { - dir = filepath.Clean(dir) + dir = filepath.Clean(dir) - // look for enclosing go.mod - for { - if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() { - return dir - } - d := filepath.Dir(dir) - if d == dir { - break - } - dir = d - } + // look for enclosing go.mod + for { + if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() { + return dir + } + d := filepath.Dir(dir) + if d == dir { + break + } + dir = d + } - return "" + return "" } diff --git a/pkg/build/goenv.go b/pkg/build/goenv.go index fa5dbf0..e50d39b 100644 --- a/pkg/build/goenv.go +++ b/pkg/build/goenv.go @@ -23,7 +23,7 @@ import ( "path/filepath" "strings" - "github.com/RickLeee/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/log" ) // readProjectMetaInfo reads all meta informations of the corresponding project diff --git a/pkg/build/inject.go b/pkg/build/inject.go index 5d020d5..1dfacb3 100644 --- a/pkg/build/inject.go +++ b/pkg/build/inject.go @@ -21,9 +21,9 @@ import ( "path/filepath" "strings" - "github.com/RickLeee/goc/v2/pkg/build/internal/tool" - "github.com/RickLeee/goc/v2/pkg/build/internal/websocket" - "github.com/RickLeee/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/build/internal/tool" + "github.com/ar0c/goc/v2/pkg/build/internal/websocket" + "github.com/ar0c/goc/v2/pkg/log" ) // Inject injects cover variables for all the .go files in the target directory diff --git a/pkg/build/install.go b/pkg/build/install.go index 086566a..e9e576f 100644 --- a/pkg/build/install.go +++ b/pkg/build/install.go @@ -14,14 +14,14 @@ package build import ( - "os" - "os/exec" + "os" + "os/exec" - "github.com/RickLeee/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/log" ) func NewInstall(opts ...gocOption) *Build { - return NewBuild(opts...) + return NewBuild(opts...) } // Install starts go install @@ -30,51 +30,51 @@ func NewInstall(opts ...gocOption) *Build { // 2. inject cover variables and functions into the project, // 3. install the project in temp. func (b *Build) Install() { - // 1. 拷贝至临时目录 - b.copyProjectToTmp() - defer b.clean() + // 1. 拷贝至临时目录 + b.copyProjectToTmp() + defer b.clean() - log.Donef("project copied to temporary directory") + log.Donef("project copied to temporary directory") - // 2. update go.mod file if needed - b.updateGoModFile() - // 3. inject cover vars - b.Inject() + // 2. update go.mod file if needed + b.updateGoModFile() + // 3. inject cover vars + b.Inject() - if b.IsVendorMod && b.IsModEdit { - b.reVendor() - } + if b.IsVendorMod && b.IsModEdit { + b.reVendor() + } - // 4. install in the temp project - b.doInstallInTemp() + // 4. install in the temp project + b.doInstallInTemp() } func (b *Build) doInstallInTemp() { - log.StartWait("installing the injected project") + log.StartWait("installing the injected project") - goflags := b.Goflags + goflags := b.Goflags - pacakges := b.Packages + pacakges := b.Packages - goflags = append(goflags, pacakges...) + goflags = append(goflags, pacakges...) - args := []string{"install"} - args = append(args, goflags...) - // go 命令行由 go install [build flags] [packages] 组成 - cmd := exec.Command("go", args...) - cmd.Dir = b.TmpWd - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + args := []string{"install"} + args = append(args, goflags...) + // go 命令行由 go install [build flags] [packages] 组成 + cmd := exec.Command("go", args...) + cmd.Dir = b.TmpWd + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr - log.Infof("go install cmd is: %v, in path [%v]", cmd.Args, cmd.Dir) - if err := cmd.Start(); err != nil { - log.Fatalf("fail to execute go install: %v", err) - } - if err := cmd.Wait(); err != nil { - log.Fatalf("fail to execute go install: %v", err) - } + log.Infof("go install cmd is: %v, in path [%v]", cmd.Args, cmd.Dir) + if err := cmd.Start(); err != nil { + log.Fatalf("fail to execute go install: %v", err) + } + if err := cmd.Wait(); err != nil { + log.Fatalf("fail to execute go install: %v", err) + } - // done - log.StopWait() - log.Donef("go install done") + // done + log.StopWait() + log.Donef("go install done") } diff --git a/pkg/build/internal/tool/cover.go b/pkg/build/internal/tool/cover.go index 4bfb090..cbe8e83 100644 --- a/pkg/build/internal/tool/cover.go +++ b/pkg/build/internal/tool/cover.go @@ -5,22 +5,22 @@ package tool import ( - "bytes" - "path" + "bytes" + "path" - // "flag" - "fmt" - "go/ast" - "go/parser" - "go/token" - "io" - "io/ioutil" - "os" - "sort" + // "flag" + "fmt" + "go/ast" + "go/parser" + "go/token" + "io" + "io/ioutil" + "os" + "sort" - "github.com/RickLeee/goc/v2/pkg/log" // QINIU - // "cmd/internal/edit" - // "cmd/internal/objabi" + "github.com/ar0c/goc/v2/pkg/log" // QINIU + // "cmd/internal/edit" + // "cmd/internal/objabi" ) // const usageMessage = "" + @@ -63,8 +63,8 @@ import ( var counterStmt func(*File, string) string const ( - atomicPackagePath = "sync/atomic" - atomicPackageName = "_cover_atomic_" + atomicPackagePath = "sync/atomic" + atomicPackageName = "_cover_atomic_" ) // func main() { @@ -149,24 +149,24 @@ const ( // Note: Our definition of basic block is based on control structures; we don't break // apart && and ||. We could but it doesn't seem important enough to bother. type Block struct { - startByte token.Pos - endByte token.Pos - numStmt int + startByte token.Pos + endByte token.Pos + numStmt int } // File is a wrapper for the state of a file used in the parser. // The basic parse tree walker is a method of this type. type File struct { - fset *token.FileSet - name string // Name of file. - astFile *ast.File - blocks []Block - content []byte - edit *Buffer // QINIU - varVar string // QINIU - mode string // QINIU - importpathFileName string // QINIU, importpath + filename - random string // QINIU, random == tmp dir name + fset *token.FileSet + name string // Name of file. + astFile *ast.File + blocks []Block + content []byte + edit *Buffer // QINIU + varVar string // QINIU + mode string // QINIU + importpathFileName string // QINIU, importpath + filename + random string // QINIU, random == tmp dir name } // findText finds text in the original source, starting at pos. @@ -174,133 +174,133 @@ type File struct { // handle quoted strings. // It returns a byte offset within f.src. func (f *File) findText(pos token.Pos, text string) int { - b := []byte(text) - start := f.offset(pos) - i := start - s := f.content - for i < len(s) { - if bytes.HasPrefix(s[i:], b) { - return i - } - if i+2 <= len(s) && s[i] == '/' && s[i+1] == '/' { - for i < len(s) && s[i] != '\n' { - i++ - } - continue - } - if i+2 <= len(s) && s[i] == '/' && s[i+1] == '*' { - for i += 2; ; i++ { - if i+2 > len(s) { - return 0 - } - if s[i] == '*' && s[i+1] == '/' { - i += 2 - break - } - } - continue - } - i++ - } - return -1 + b := []byte(text) + start := f.offset(pos) + i := start + s := f.content + for i < len(s) { + if bytes.HasPrefix(s[i:], b) { + return i + } + if i+2 <= len(s) && s[i] == '/' && s[i+1] == '/' { + for i < len(s) && s[i] != '\n' { + i++ + } + continue + } + if i+2 <= len(s) && s[i] == '/' && s[i+1] == '*' { + for i += 2; ; i++ { + if i+2 > len(s) { + return 0 + } + if s[i] == '*' && s[i+1] == '/' { + i += 2 + break + } + } + continue + } + i++ + } + return -1 } // Visit implements the ast.Visitor interface. func (f *File) Visit(node ast.Node) ast.Visitor { - switch n := node.(type) { - case *ast.BlockStmt: - // If it's a switch or select, the body is a list of case clauses; don't tag the block itself. - if len(n.List) > 0 { - switch n.List[0].(type) { - case *ast.CaseClause: // switch - for _, n := range n.List { - clause := n.(*ast.CaseClause) - f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false) - } - return f - case *ast.CommClause: // select - for _, n := range n.List { - clause := n.(*ast.CommClause) - f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false) - } - return f - } - } - f.addCounters(n.Lbrace, n.Lbrace+1, n.Rbrace+1, n.List, true) // +1 to step past closing brace. - case *ast.IfStmt: - if n.Init != nil { - ast.Walk(f, n.Init) - } - ast.Walk(f, n.Cond) - ast.Walk(f, n.Body) - if n.Else == nil { - return nil - } - // The elses are special, because if we have - // if x { - // } else if y { - // } - // we want to cover the "if y". To do this, we need a place to drop the counter, - // so we add a hidden block: - // if x { - // } else { - // if y { - // } - // } - elseOffset := f.findText(n.Body.End(), "else") - if elseOffset < 0 { - panic("lost else") - } - f.edit.Insert(elseOffset+4, "{") - f.edit.Insert(f.offset(n.Else.End()), "}") + switch n := node.(type) { + case *ast.BlockStmt: + // If it's a switch or select, the body is a list of case clauses; don't tag the block itself. + if len(n.List) > 0 { + switch n.List[0].(type) { + case *ast.CaseClause: // switch + for _, n := range n.List { + clause := n.(*ast.CaseClause) + f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false) + } + return f + case *ast.CommClause: // select + for _, n := range n.List { + clause := n.(*ast.CommClause) + f.addCounters(clause.Colon+1, clause.Colon+1, clause.End(), clause.Body, false) + } + return f + } + } + f.addCounters(n.Lbrace, n.Lbrace+1, n.Rbrace+1, n.List, true) // +1 to step past closing brace. + case *ast.IfStmt: + if n.Init != nil { + ast.Walk(f, n.Init) + } + ast.Walk(f, n.Cond) + ast.Walk(f, n.Body) + if n.Else == nil { + return nil + } + // The elses are special, because if we have + // if x { + // } else if y { + // } + // we want to cover the "if y". To do this, we need a place to drop the counter, + // so we add a hidden block: + // if x { + // } else { + // if y { + // } + // } + elseOffset := f.findText(n.Body.End(), "else") + if elseOffset < 0 { + panic("lost else") + } + f.edit.Insert(elseOffset+4, "{") + f.edit.Insert(f.offset(n.Else.End()), "}") - // We just created a block, now walk it. - // Adjust the position of the new block to start after - // the "else". That will cause it to follow the "{" - // we inserted above. - pos := f.fset.File(n.Body.End()).Pos(elseOffset + 4) - switch stmt := n.Else.(type) { - case *ast.IfStmt: - block := &ast.BlockStmt{ - Lbrace: pos, - List: []ast.Stmt{stmt}, - Rbrace: stmt.End(), - } - n.Else = block - case *ast.BlockStmt: - stmt.Lbrace = pos - default: - panic("unexpected node type in if") - } - ast.Walk(f, n.Else) - return nil - case *ast.SelectStmt: - // Don't annotate an empty select - creates a syntax error. - if n.Body == nil || len(n.Body.List) == 0 { - return nil - } - case *ast.SwitchStmt: - // Don't annotate an empty switch - creates a syntax error. - if n.Body == nil || len(n.Body.List) == 0 { - if n.Init != nil { - ast.Walk(f, n.Init) - } - if n.Tag != nil { - ast.Walk(f, n.Tag) - } - return nil - } - case *ast.TypeSwitchStmt: - // Don't annotate an empty type switch - creates a syntax error. - if n.Body == nil || len(n.Body.List) == 0 { - if n.Init != nil { - ast.Walk(f, n.Init) - } - ast.Walk(f, n.Assign) - return nil - } - } - return f + // We just created a block, now walk it. + // Adjust the position of the new block to start after + // the "else". That will cause it to follow the "{" + // we inserted above. + pos := f.fset.File(n.Body.End()).Pos(elseOffset + 4) + switch stmt := n.Else.(type) { + case *ast.IfStmt: + block := &ast.BlockStmt{ + Lbrace: pos, + List: []ast.Stmt{stmt}, + Rbrace: stmt.End(), + } + n.Else = block + case *ast.BlockStmt: + stmt.Lbrace = pos + default: + panic("unexpected node type in if") + } + ast.Walk(f, n.Else) + return nil + case *ast.SelectStmt: + // Don't annotate an empty select - creates a syntax error. + if n.Body == nil || len(n.Body.List) == 0 { + return nil + } + case *ast.SwitchStmt: + // Don't annotate an empty switch - creates a syntax error. + if n.Body == nil || len(n.Body.List) == 0 { + if n.Init != nil { + ast.Walk(f, n.Init) + } + if n.Tag != nil { + ast.Walk(f, n.Tag) + } + return nil + } + case *ast.TypeSwitchStmt: + // Don't annotate an empty type switch - creates a syntax error. + if n.Body == nil || len(n.Body.List) == 0 { + if n.Init != nil { + ast.Walk(f, n.Init) + } + ast.Walk(f, n.Assign) + return nil + } + } + return f } // QINIU @@ -309,126 +309,126 @@ func (f *File) Visit(node ast.Node) ast.Visitor { // 2. return the cover variables declarations as plain string // original dec: func annotate(name string) { func Annotate(name string, mode string, varVar string, importpathFilename string, globalCoverVarImportPath string) string { - // QINIU - switch mode { - case "set": - counterStmt = setCounterStmt - case "count": - counterStmt = incCounterStmt - case "atomic": - counterStmt = atomicCounterStmt - case "watch": - counterStmt = watchCounterStmt - default: - counterStmt = incCounterStmt - } + // QINIU + switch mode { + case "set": + counterStmt = setCounterStmt + case "count": + counterStmt = incCounterStmt + case "atomic": + counterStmt = atomicCounterStmt + case "watch": + counterStmt = watchCounterStmt + default: + counterStmt = incCounterStmt + } - fset := token.NewFileSet() - content, err := ioutil.ReadFile(name) - if err != nil { - log.Fatalf("cover: %s: %s", name, err) - } - parsedFile, err := parser.ParseFile(fset, name, content, parser.ParseComments) - if err != nil { - log.Fatalf("cover: %s: %s", name, err) - } + fset := token.NewFileSet() + content, err := ioutil.ReadFile(name) + if err != nil { + log.Fatalf("cover: %s: %s", name, err) + } + parsedFile, err := parser.ParseFile(fset, name, content, parser.ParseComments) + if err != nil { + log.Fatalf("cover: %s: %s", name, err) + } - file := &File{ - fset: fset, - name: name, - content: content, - edit: NewBuffer(content), // QINIU - astFile: parsedFile, - varVar: varVar, // QINIU - mode: mode, // QINIU - importpathFileName: importpathFilename, // QINIU - random: path.Base(globalCoverVarImportPath), // QINIU - } + file := &File{ + fset: fset, + name: name, + content: content, + edit: NewBuffer(content), // QINIU + astFile: parsedFile, + varVar: varVar, // QINIU + mode: mode, // QINIU + importpathFileName: importpathFilename, // QINIU + random: path.Base(globalCoverVarImportPath), // QINIU + } - ast.Walk(file, file.astFile) - newContent := file.edit.Bytes() + ast.Walk(file, file.astFile) + newContent := file.edit.Bytes() - if bytes.Equal(content, newContent) { - log.Debugf("no cover var injected for: ", name) - } else { - // reback to the beginning - file.astFile, _ = parser.ParseFile(fset, name, content, parser.ParseComments) - file.edit = NewBuffer(newContent) - // add global cover variables import path - file.edit.Insert(file.offset(file.astFile.Name.End()), - fmt.Sprintf("; import %s %q", ".", globalCoverVarImportPath)) + if bytes.Equal(content, newContent) { + log.Debugf("no cover var injected for: ", name) + } else { + // reback to the beginning + file.astFile, _ = parser.ParseFile(fset, name, content, parser.ParseComments) + file.edit = NewBuffer(newContent) + // add global cover variables import path + file.edit.Insert(file.offset(file.astFile.Name.End()), + fmt.Sprintf("; import %s %q", ".", globalCoverVarImportPath)) - if mode == "atomic" { - // Add import of sync/atomic immediately after package clause. - // We do this even if there is an existing import, because the - // existing import may be shadowed at any given place we want - // to refer to it, and our name (_cover_atomic_) is less likely to - // be shadowed. - file.edit.Insert(file.offset(file.astFile.Name.End()), - fmt.Sprintf("; import %s %q", atomicPackageName, atomicPackagePath)) - } + if mode == "atomic" { + // Add import of sync/atomic immediately after package clause. + // We do this even if there is an existing import, because the + // existing import may be shadowed at any given place we want + // to refer to it, and our name (_cover_atomic_) is less likely to + // be shadowed. + file.edit.Insert(file.offset(file.astFile.Name.End()), + fmt.Sprintf("; import %s %q", atomicPackageName, atomicPackagePath)) + } - newContent = file.edit.Bytes() - } + newContent = file.edit.Bytes() + } - // fd := os.Stdout - // if *output != "" { - // var err error - // fd, err = os.Create(*output) - // if err != nil { - // log.Fatalf("cover: %s", err) - // } - // } - fd, err := os.Create(name) - if err != nil { - log.Fatalf("cover: %s", err) - } - defer fd.Close() + // fd := os.Stdout + // if *output != "" { + // var err error + // fd, err = os.Create(*output) + // if err != nil { + // log.Fatalf("cover: %s", err) + // } + // } + fd, err := os.Create(name) + if err != nil { + log.Fatalf("cover: %s", err) + } + defer fd.Close() - fmt.Fprintf(fd, "//line %s:1\n", name) - _, err = fd.Write(newContent) - if err != nil { - log.Fatalf("cover: %s", err) - } + fmt.Fprintf(fd, "//line %s:1\n", name) + _, err = fd.Write(newContent) + if err != nil { + log.Fatalf("cover: %s", err) + } - // After printing the source tree, add some declarations for the counters etc. - // We could do this by adding to the tree, but it's easier just to print the text. + // After printing the source tree, add some declarations for the counters etc. + // We could do this by adding to the tree, but it's easier just to print the text. - // QINIU - // declarations only print to string - // we will write all declarations into a single file - declBuf := bytes.NewBufferString("") - file.addVariables(declBuf) - return declBuf.String() + // QINIU + // declarations only print to string + // we will write all declarations into a single file + declBuf := bytes.NewBufferString("") + file.addVariables(declBuf) + return declBuf.String() } // setCounterStmt returns the expression: __count[23] = 1. func setCounterStmt(f *File, counter string) string { - return fmt.Sprintf("%s = 1", counter) + return fmt.Sprintf("%s = 1", counter) } // incCounterStmt returns the expression: __count[23]++. func incCounterStmt(f *File, counter string) string { - return fmt.Sprintf("%s++", counter) + return fmt.Sprintf("%s++", counter) } // atomicCounterStmt returns the expression: atomic.AddUint32(&__count[23], 1) func atomicCounterStmt(f *File, counter string) string { - return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter) + return fmt.Sprintf("%s.AddUint32(&%s, 1)", atomicPackageName, counter) } // watchCounterStmt returns the expression: __count[23]++;UploadCoverChangeEvent(blockname, pos[:], index) func watchCounterStmt(f *File, counter string) string { - index := len(f.blocks) - return fmt.Sprintf("%s++; UploadCoverChangeEvent_%v(%s.BlockName, %s.Pos[:], %v, %s.NumStmt[%v])", counter, f.random, f.varVar, f.varVar, index, f.varVar, index) + index := len(f.blocks) + return fmt.Sprintf("%s++; UploadCoverChangeEvent_%v(%s.BlockName, %s.Pos[:], %v, %s.NumStmt[%v])", counter, f.random, f.varVar, f.varVar, index, f.varVar, index) } // QINIU // newCounter creates a new counter expression of the appropriate form. func (f *File) newCounter(start, end token.Pos, numStmt int) string { - stmt := counterStmt(f, fmt.Sprintf("%s.Count[%d]", f.varVar, len(f.blocks))) - f.blocks = append(f.blocks, Block{start, end, numStmt}) - return stmt + stmt := counterStmt(f, fmt.Sprintf("%s.Count[%d]", f.varVar, len(f.blocks))) + f.blocks = append(f.blocks, Block{start, end, numStmt}) + return stmt } // addCounters takes a list of statements and adds counters to the beginning of @@ -444,68 +444,68 @@ func (f *File) newCounter(start, end token.Pos, numStmt int) string { // will be visited in a separate call. // TODO: Nested simple blocks get unnecessary (but correct) counters func (f *File) addCounters(pos, insertPos, blockEnd token.Pos, list []ast.Stmt, extendToClosingBrace bool) { - // Special case: make sure we add a counter to an empty block. Can't do this below - // or we will add a counter to an empty statement list after, say, a return statement. - if len(list) == 0 { - f.edit.Insert(f.offset(insertPos), f.newCounter(insertPos, blockEnd, 0)+";") - return - } - // Make a copy of the list, as we may mutate it and should leave the - // existing list intact. - list = append([]ast.Stmt(nil), list...) - // We have a block (statement list), but it may have several basic blocks due to the - // appearance of statements that affect the flow of control. - for { - // Find first statement that affects flow of control (break, continue, if, etc.). - // It will be the last statement of this basic block. - var last int - end := blockEnd - for last = 0; last < len(list); last++ { - stmt := list[last] - end = f.statementBoundary(stmt) - if f.endsBasicSourceBlock(stmt) { - // If it is a labeled statement, we need to place a counter between - // the label and its statement because it may be the target of a goto - // and thus start a basic block. That is, given - // foo: stmt - // we need to create - // foo: ; stmt - // and mark the label as a block-terminating statement. - // The result will then be - // foo: COUNTER[n]++; stmt - // However, we can't do this if the labeled statement is already - // a control statement, such as a labeled for. - if label, isLabel := stmt.(*ast.LabeledStmt); isLabel && !f.isControl(label.Stmt) { - newLabel := *label - newLabel.Stmt = &ast.EmptyStmt{ - Semicolon: label.Stmt.Pos(), - Implicit: true, - } - end = label.Pos() // Previous block ends before the label. - list[last] = &newLabel - // Open a gap and drop in the old statement, now without a label. - list = append(list, nil) - copy(list[last+1:], list[last:]) - list[last+1] = label.Stmt - } - last++ - extendToClosingBrace = false // Block is broken up now. - break - } - } - if extendToClosingBrace { - end = blockEnd - } - if pos != end { // Can have no source to cover if e.g. blocks abut. - f.edit.Insert(f.offset(insertPos), f.newCounter(pos, end, last)+";") - } - list = list[last:] - if len(list) == 0 { - break - } - pos = list[0].Pos() - insertPos = pos - } + // Special case: make sure we add a counter to an empty block. Can't do this below + // or we will add a counter to an empty statement list after, say, a return statement. + if len(list) == 0 { + f.edit.Insert(f.offset(insertPos), f.newCounter(insertPos, blockEnd, 0)+";") + return + } + // Make a copy of the list, as we may mutate it and should leave the + // existing list intact. + list = append([]ast.Stmt(nil), list...) + // We have a block (statement list), but it may have several basic blocks due to the + // appearance of statements that affect the flow of control. + for { + // Find first statement that affects flow of control (break, continue, if, etc.). + // It will be the last statement of this basic block. + var last int + end := blockEnd + for last = 0; last < len(list); last++ { + stmt := list[last] + end = f.statementBoundary(stmt) + if f.endsBasicSourceBlock(stmt) { + // If it is a labeled statement, we need to place a counter between + // the label and its statement because it may be the target of a goto + // and thus start a basic block. That is, given + // foo: stmt + // we need to create + // foo: ; stmt + // and mark the label as a block-terminating statement. + // The result will then be + // foo: COUNTER[n]++; stmt + // However, we can't do this if the labeled statement is already + // a control statement, such as a labeled for. + if label, isLabel := stmt.(*ast.LabeledStmt); isLabel && !f.isControl(label.Stmt) { + newLabel := *label + newLabel.Stmt = &ast.EmptyStmt{ + Semicolon: label.Stmt.Pos(), + Implicit: true, + } + end = label.Pos() // Previous block ends before the label. + list[last] = &newLabel + // Open a gap and drop in the old statement, now without a label. + list = append(list, nil) + copy(list[last+1:], list[last:]) + list[last+1] = label.Stmt + } + last++ + extendToClosingBrace = false // Block is broken up now. + break + } + } + if extendToClosingBrace { + end = blockEnd + } + if pos != end { // Can have no source to cover if e.g. blocks abut. + f.edit.Insert(f.offset(insertPos), f.newCounter(pos, end, last)+";") + } + list = list[last:] + if len(list) == 0 { + break + } + pos = list[0].Pos() + insertPos = pos + } } // hasFuncLiteral reports the existence and position of the first func literal @@ -514,131 +514,131 @@ func (f *File) addCounters(pos, insertPos, blockEnd token.Pos, list []ast.Stmt, // Therefore we draw a line at the start of the body of the first function literal we find. // TODO: what if there's more than one? Probably doesn't matter much. func hasFuncLiteral(n ast.Node) (bool, token.Pos) { - if n == nil { - return false, 0 - } - var literal funcLitFinder - ast.Walk(&literal, n) - return literal.found(), token.Pos(literal) + if n == nil { + return false, 0 + } + var literal funcLitFinder + ast.Walk(&literal, n) + return literal.found(), token.Pos(literal) } // statementBoundary finds the location in s that terminates the current basic // block in the source. func (f *File) statementBoundary(s ast.Stmt) token.Pos { - // Control flow statements are easy. - switch s := s.(type) { - case *ast.BlockStmt: - // Treat blocks like basic blocks to avoid overlapping counters. - return s.Lbrace - case *ast.IfStmt: - found, pos := hasFuncLiteral(s.Init) - if found { - return pos - } - found, pos = hasFuncLiteral(s.Cond) - if found { - return pos - } - return s.Body.Lbrace - case *ast.ForStmt: - found, pos := hasFuncLiteral(s.Init) - if found { - return pos - } - found, pos = hasFuncLiteral(s.Cond) - if found { - return pos - } - found, pos = hasFuncLiteral(s.Post) - if found { - return pos - } - return s.Body.Lbrace - case *ast.LabeledStmt: - return f.statementBoundary(s.Stmt) - case *ast.RangeStmt: - found, pos := hasFuncLiteral(s.X) - if found { - return pos - } - return s.Body.Lbrace - case *ast.SwitchStmt: - found, pos := hasFuncLiteral(s.Init) - if found { - return pos - } - found, pos = hasFuncLiteral(s.Tag) - if found { - return pos - } - return s.Body.Lbrace - case *ast.SelectStmt: - return s.Body.Lbrace - case *ast.TypeSwitchStmt: - found, pos := hasFuncLiteral(s.Init) - if found { - return pos - } - return s.Body.Lbrace - } - // If not a control flow statement, it is a declaration, expression, call, etc. and it may have a function literal. - // If it does, that's tricky because we want to exclude the body of the function from this block. - // Draw a line at the start of the body of the first function literal we find. - // TODO: what if there's more than one? Probably doesn't matter much. - found, pos := hasFuncLiteral(s) - if found { - return pos - } - return s.End() + // Control flow statements are easy. + switch s := s.(type) { + case *ast.BlockStmt: + // Treat blocks like basic blocks to avoid overlapping counters. + return s.Lbrace + case *ast.IfStmt: + found, pos := hasFuncLiteral(s.Init) + if found { + return pos + } + found, pos = hasFuncLiteral(s.Cond) + if found { + return pos + } + return s.Body.Lbrace + case *ast.ForStmt: + found, pos := hasFuncLiteral(s.Init) + if found { + return pos + } + found, pos = hasFuncLiteral(s.Cond) + if found { + return pos + } + found, pos = hasFuncLiteral(s.Post) + if found { + return pos + } + return s.Body.Lbrace + case *ast.LabeledStmt: + return f.statementBoundary(s.Stmt) + case *ast.RangeStmt: + found, pos := hasFuncLiteral(s.X) + if found { + return pos + } + return s.Body.Lbrace + case *ast.SwitchStmt: + found, pos := hasFuncLiteral(s.Init) + if found { + return pos + } + found, pos = hasFuncLiteral(s.Tag) + if found { + return pos + } + return s.Body.Lbrace + case *ast.SelectStmt: + return s.Body.Lbrace + case *ast.TypeSwitchStmt: + found, pos := hasFuncLiteral(s.Init) + if found { + return pos + } + return s.Body.Lbrace + } + // If not a control flow statement, it is a declaration, expression, call, etc. and it may have a function literal. + // If it does, that's tricky because we want to exclude the body of the function from this block. + // Draw a line at the start of the body of the first function literal we find. + // TODO: what if there's more than one? Probably doesn't matter much. + found, pos := hasFuncLiteral(s) + if found { + return pos + } + return s.End() } // endsBasicSourceBlock reports whether s changes the flow of control: break, if, etc., // or if it's just problematic, for instance contains a function literal, which will complicate // accounting due to the block-within-an expression. func (f *File) endsBasicSourceBlock(s ast.Stmt) bool { - switch s := s.(type) { - case *ast.BlockStmt: - // Treat blocks like basic blocks to avoid overlapping counters. - return true - case *ast.BranchStmt: - return true - case *ast.ForStmt: - return true - case *ast.IfStmt: - return true - case *ast.LabeledStmt: - return true // A goto may branch here, starting a new basic block. - case *ast.RangeStmt: - return true - case *ast.SwitchStmt: - return true - case *ast.SelectStmt: - return true - case *ast.TypeSwitchStmt: - return true - case *ast.ExprStmt: - // Calls to panic change the flow. - // We really should verify that "panic" is the predefined function, - // but without type checking we can't and the likelihood of it being - // an actual problem is vanishingly small. - if call, ok := s.X.(*ast.CallExpr); ok { - if ident, ok := call.Fun.(*ast.Ident); ok && ident.Name == "panic" && len(call.Args) == 1 { - return true - } - } - } - found, _ := hasFuncLiteral(s) - return found + switch s := s.(type) { + case *ast.BlockStmt: + // Treat blocks like basic blocks to avoid overlapping counters. + return true + case *ast.BranchStmt: + return true + case *ast.ForStmt: + return true + case *ast.IfStmt: + return true + case *ast.LabeledStmt: + return true // A goto may branch here, starting a new basic block. + case *ast.RangeStmt: + return true + case *ast.SwitchStmt: + return true + case *ast.SelectStmt: + return true + case *ast.TypeSwitchStmt: + return true + case *ast.ExprStmt: + // Calls to panic change the flow. + // We really should verify that "panic" is the predefined function, + // but without type checking we can't and the likelihood of it being + // an actual problem is vanishingly small. + if call, ok := s.X.(*ast.CallExpr); ok { + if ident, ok := call.Fun.(*ast.Ident); ok && ident.Name == "panic" && len(call.Args) == 1 { + return true + } + } + } + found, _ := hasFuncLiteral(s) + return found } // isControl reports whether s is a control statement that, if labeled, cannot be // separated from its label. func (f *File) isControl(s ast.Stmt) bool { - switch s.(type) { - case *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt, *ast.SelectStmt, *ast.TypeSwitchStmt: - return true - } - return false + switch s.(type) { + case *ast.ForStmt, *ast.RangeStmt, *ast.SwitchStmt, *ast.SelectStmt, *ast.TypeSwitchStmt: + return true + } + return false } // funcLitFinder implements the ast.Visitor pattern to find the location of any @@ -646,26 +646,26 @@ func (f *File) isControl(s ast.Stmt) bool { type funcLitFinder token.Pos func (f *funcLitFinder) Visit(node ast.Node) (w ast.Visitor) { - if f.found() { - return nil // Prune search. - } - switch n := node.(type) { - case *ast.FuncLit: - *f = funcLitFinder(n.Body.Lbrace) - return nil // Prune search. - } - return f + if f.found() { + return nil // Prune search. + } + switch n := node.(type) { + case *ast.FuncLit: + *f = funcLitFinder(n.Body.Lbrace) + return nil // Prune search. + } + return f } func (f *funcLitFinder) found() bool { - return token.Pos(*f) != token.NoPos + return token.Pos(*f) != token.NoPos } // Sort interface for []block1; used for self-check in addVariables. type block1 struct { - Block - index int + Block + index int } type blockSlice []block1 @@ -676,83 +676,83 @@ func (b blockSlice) Swap(i, j int) { b[i], b[j] = b[j], b[i] } // offset translates a token position into a 0-indexed byte offset. func (f *File) offset(pos token.Pos) int { - return f.fset.Position(pos).Offset + return f.fset.Position(pos).Offset } // addVariables adds to the end of the file the declarations to set up the counter and position variables. func (f *File) addVariables(w io.Writer) { - // Self-check: Verify that the instrumented basic blocks are disjoint. - t := make([]block1, len(f.blocks)) - for i := range f.blocks { - t[i].Block = f.blocks[i] - t[i].index = i - } - sort.Sort(blockSlice(t)) - for i := 1; i < len(t); i++ { - if t[i-1].endByte > t[i].startByte { - fmt.Fprintf(os.Stderr, "cover: internal error: block %d overlaps block %d\n", t[i-1].index, t[i].index) - // Note: error message is in byte positions, not token positions. - fmt.Fprintf(os.Stderr, "\t%s:#%d,#%d %s:#%d,#%d\n", - f.name, f.offset(t[i-1].startByte), f.offset(t[i-1].endByte), - f.name, f.offset(t[i].startByte), f.offset(t[i].endByte)) - } - } + // Self-check: Verify that the instrumented basic blocks are disjoint. + t := make([]block1, len(f.blocks)) + for i := range f.blocks { + t[i].Block = f.blocks[i] + t[i].index = i + } + sort.Sort(blockSlice(t)) + for i := 1; i < len(t); i++ { + if t[i-1].endByte > t[i].startByte { + fmt.Fprintf(os.Stderr, "cover: internal error: block %d overlaps block %d\n", t[i-1].index, t[i].index) + // Note: error message is in byte positions, not token positions. + fmt.Fprintf(os.Stderr, "\t%s:#%d,#%d %s:#%d,#%d\n", + f.name, f.offset(t[i-1].startByte), f.offset(t[i-1].endByte), + f.name, f.offset(t[i].startByte), f.offset(t[i].endByte)) + } + } - // Declare the coverage struct as a package-level variable. - fmt.Fprintf(w, "\nvar %s = struct {\n", f.varVar) // QINIU - fmt.Fprintf(w, "\tCount [%d]uint32\n", len(f.blocks)) - fmt.Fprintf(w, "\tPos [3 * %d]uint32\n", len(f.blocks)) - fmt.Fprintf(w, "\tNumStmt [%d]uint16\n", len(f.blocks)) - fmt.Fprintf(w, "\tBlockName string\n") // QINIU - fmt.Fprintf(w, "} {\n") + // Declare the coverage struct as a package-level variable. + fmt.Fprintf(w, "\nvar %s = struct {\n", f.varVar) // QINIU + fmt.Fprintf(w, "\tCount [%d]uint32\n", len(f.blocks)) + fmt.Fprintf(w, "\tPos [3 * %d]uint32\n", len(f.blocks)) + fmt.Fprintf(w, "\tNumStmt [%d]uint16\n", len(f.blocks)) + fmt.Fprintf(w, "\tBlockName string\n") // QINIU + fmt.Fprintf(w, "} {\n") - // 写入 BlockName 初始化 - fmt.Fprintf(w, "\tBlockName: \"%v\",\n", f.importpathFileName) + // 写入 BlockName 初始化 + fmt.Fprintf(w, "\tBlockName: \"%v\",\n", f.importpathFileName) - // Initialize the position array field. - fmt.Fprintf(w, "\tPos: [3 * %d]uint32{\n", len(f.blocks)) + // Initialize the position array field. + fmt.Fprintf(w, "\tPos: [3 * %d]uint32{\n", len(f.blocks)) - // A nice long list of positions. Each position is encoded as follows to reduce size: - // - 32-bit starting line number - // - 32-bit ending line number - // - (16 bit ending column number << 16) | (16-bit starting column number). - for i, block := range f.blocks { - start := f.fset.Position(block.startByte) - end := f.fset.Position(block.endByte) + // A nice long list of positions. Each position is encoded as follows to reduce size: + // - 32-bit starting line number + // - 32-bit ending line number + // - (16 bit ending column number << 16) | (16-bit starting column number). + for i, block := range f.blocks { + start := f.fset.Position(block.startByte) + end := f.fset.Position(block.endByte) - start, end = dedup(start, end) + start, end = dedup(start, end) - fmt.Fprintf(w, "\t\t%d, %d, %#x, // [%d]\n", start.Line, end.Line, (end.Column&0xFFFF)<<16|(start.Column&0xFFFF), i) - } + fmt.Fprintf(w, "\t\t%d, %d, %#x, // [%d]\n", start.Line, end.Line, (end.Column&0xFFFF)<<16|(start.Column&0xFFFF), i) + } - // Close the position array. - fmt.Fprintf(w, "\t},\n") + // Close the position array. + fmt.Fprintf(w, "\t},\n") - // Initialize the position array field. - fmt.Fprintf(w, "\tNumStmt: [%d]uint16{\n", len(f.blocks)) + // Initialize the position array field. + fmt.Fprintf(w, "\tNumStmt: [%d]uint16{\n", len(f.blocks)) - // A nice long list of statements-per-block, so we can give a conventional - // valuation of "percent covered". To save space, it's a 16-bit number, so we - // clamp it if it overflows - won't matter in practice. - for i, block := range f.blocks { - n := block.numStmt - if n > 1<<16-1 { - n = 1<<16 - 1 - } - fmt.Fprintf(w, "\t\t%d, // %d\n", n, i) - } + // A nice long list of statements-per-block, so we can give a conventional + // valuation of "percent covered". To save space, it's a 16-bit number, so we + // clamp it if it overflows - won't matter in practice. + for i, block := range f.blocks { + n := block.numStmt + if n > 1<<16-1 { + n = 1<<16 - 1 + } + fmt.Fprintf(w, "\t\t%d, // %d\n", n, i) + } - // Close the statements-per-block array. - fmt.Fprintf(w, "\t},\n") + // Close the statements-per-block array. + fmt.Fprintf(w, "\t},\n") - // Close the struct initialization. - fmt.Fprintf(w, "}\n") + // Close the struct initialization. + fmt.Fprintf(w, "}\n") - // Emit a reference to the atomic package to avoid - // import and not used error when there's no code in a file. - // if f.mode == "atomic" { // QINIU, no need to import - // fmt.Fprintf(w, "var _ = %s.LoadUint32\n", atomicPackageName) - // } + // Emit a reference to the atomic package to avoid + // import and not used error when there's no code in a file. + // if f.mode == "atomic" { // QINIU, no need to import + // fmt.Fprintf(w, "var _ = %s.LoadUint32\n", atomicPackageName) + // } } // It is possible for positions to repeat when there is a line @@ -764,7 +764,7 @@ func (f *File) addVariables(w io.Writer) { // pos2 is a pair of token.Position values, used as a map key type. type pos2 struct { - p1, p2 token.Position + p1, p2 token.Position } // seenPos2 tracks whether we have seen a token.Position pair. @@ -774,20 +774,20 @@ var seenPos2 = make(map[pos2]bool) // duplicate any existing pair. The returned pair will have the Offset // fields cleared. func dedup(p1, p2 token.Position) (r1, r2 token.Position) { - key := pos2{ - p1: p1, - p2: p2, - } + key := pos2{ + p1: p1, + p2: p2, + } - // We want to ignore the Offset fields in the map, - // since cover uses only file/line/column. - key.p1.Offset = 0 - key.p2.Offset = 0 + // We want to ignore the Offset fields in the map, + // since cover uses only file/line/column. + key.p1.Offset = 0 + key.p2.Offset = 0 - for seenPos2[key] { - key.p2.Column++ - } - seenPos2[key] = true + for seenPos2[key] { + key.p2.Column++ + } + seenPos2[key] = true - return key.p1, key.p2 + return key.p1, key.p2 } diff --git a/pkg/build/internal/websocket/wsdep.go b/pkg/build/internal/websocket/wsdep.go index 591e0a9..8b12aa4 100644 --- a/pkg/build/internal/websocket/wsdep.go +++ b/pkg/build/internal/websocket/wsdep.go @@ -14,14 +14,14 @@ package websocket import ( - "archive/tar" - "bytes" - "embed" - "io" - "os" - "path/filepath" + "archive/tar" + "bytes" + "embed" + "io" + "os" + "path/filepath" - "github.com/RickLeee/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/log" ) //go:embed websocket.tar @@ -32,48 +32,48 @@ var depTarFile embed.FS // 从 embed 文件系统中解压 websocket.tar 文件,并依次写入临时工程中,作为一个单独的包存在。 // gorrila/websocket 是一个无第三方依赖的库,因此其位置可以随处移动,而不影响自身的编译。 func AddCustomWebsocketDep(customWebsocketPath string) { - data, err := depTarFile.ReadFile("websocket.tar") - if err != nil { - log.Fatalf("cannot find the websocket.tar in the embed file: %v", err) - } + data, err := depTarFile.ReadFile("websocket.tar") + if err != nil { + log.Fatalf("cannot find the websocket.tar in the embed file: %v", err) + } - buf := bytes.NewBuffer(data) - tr := tar.NewReader(buf) - for { - hdr, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - log.Fatalf("cannot untar the websocket.tar: %v", err) - } + buf := bytes.NewBuffer(data) + tr := tar.NewReader(buf) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + log.Fatalf("cannot untar the websocket.tar: %v", err) + } - fpath := filepath.Join(customWebsocketPath, hdr.Name) - if hdr.FileInfo().IsDir() { - // 处理目录 - err := os.MkdirAll(fpath, hdr.FileInfo().Mode()) - if err != nil { - log.Fatalf("fail to untar the websocket.tar: %v", err) - } - } else { - // 处理文件 - fdir := filepath.Dir(fpath) - err := os.MkdirAll(fdir, hdr.FileInfo().Mode()) - if err != nil { - log.Fatalf("fail to untar the websocket.tar: %v", err) - } + fpath := filepath.Join(customWebsocketPath, hdr.Name) + if hdr.FileInfo().IsDir() { + // 处理目录 + err := os.MkdirAll(fpath, hdr.FileInfo().Mode()) + if err != nil { + log.Fatalf("fail to untar the websocket.tar: %v", err) + } + } else { + // 处理文件 + fdir := filepath.Dir(fpath) + err := os.MkdirAll(fdir, hdr.FileInfo().Mode()) + if err != nil { + log.Fatalf("fail to untar the websocket.tar: %v", err) + } - f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode()) - if err != nil { - log.Fatalf("fail to untar the websocket.tar: %v", err) - } - defer f.Close() + f, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, hdr.FileInfo().Mode()) + if err != nil { + log.Fatalf("fail to untar the websocket.tar: %v", err) + } + defer f.Close() - _, err = io.Copy(f, tr) + _, err = io.Copy(f, tr) - if err != nil { - log.Fatalf("fail to untar the websocket.tar: %v", err) - } - } - } + if err != nil { + log.Fatalf("fail to untar the websocket.tar: %v", err) + } + } + } } diff --git a/pkg/build/run.go b/pkg/build/run.go index ee81db2..b82831c 100644 --- a/pkg/build/run.go +++ b/pkg/build/run.go @@ -14,33 +14,33 @@ package build import ( - "os" - "os/exec" - "os/signal" + "os" + "os/exec" + "os/signal" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/RickLeee/goc/v2/pkg/server" - "github.com/RickLeee/goc/v2/pkg/server/store" - "github.com/gin-gonic/gin" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/server" + "github.com/ar0c/goc/v2/pkg/server/store" + "github.com/gin-gonic/gin" ) func NewRun(opts ...gocOption) *Build { - b := &Build{} + b := &Build{} - for _, opt := range opts { - opt(b) - } + for _, opt := range opts { + opt(b) + } - // 1. 解析 goc 命令行和 go 命令行 - b.runCmdArgsParse() - // 2. 解析 go 包位置 - // b.getPackagesDir() - // 3. 读取工程元信息:go.mod, pkgs list ... - b.readProjectMetaInfo() - // 4. 展示元信息 - b.displayProjectMetaInfo() + // 1. 解析 goc 命令行和 go 命令行 + b.runCmdArgsParse() + // 2. 解析 go 包位置 + // b.getPackagesDir() + // 3. 读取工程元信息:go.mod, pkgs list ... + b.readProjectMetaInfo() + // 4. 展示元信息 + b.displayProjectMetaInfo() - return b + return b } // Run starts go run @@ -49,58 +49,58 @@ func NewRun(opts ...gocOption) *Build { // 2. inject cover variables and functions into the project, // 3. run the project in temp. func (b *Build) Run() { - // 1. 拷贝至临时目录 - b.copyProjectToTmp() - defer b.clean() + // 1. 拷贝至临时目录 + b.copyProjectToTmp() + defer b.clean() - log.Donef("project copied to temporary directory") + log.Donef("project copied to temporary directory") - // 2. update go.mod file if needed - b.updateGoModFile() - // 3. inject cover vars - b.Inject() + // 2. update go.mod file if needed + b.updateGoModFile() + // 3. inject cover vars + b.Inject() - if b.IsVendorMod && b.IsModEdit { - b.reVendor() - } + if b.IsVendorMod && b.IsModEdit { + b.reVendor() + } - // 4. run in the temp project - go func() { - ch := make(chan os.Signal, 1) - signal.Notify(ch, os.Interrupt) - <-ch - b.clean() - }() - b.doRunInTemp() + // 4. run in the temp project + go func() { + ch := make(chan os.Signal, 1) + signal.Notify(ch, os.Interrupt) + <-ch + b.clean() + }() + b.doRunInTemp() } func (b *Build) doRunInTemp() { - log.Infof("running the injected project") + log.Infof("running the injected project") - s := store.NewFakeStore() - go func() { - gin.SetMode(gin.ReleaseMode) - err := server.RunGocServerUntilExit(b.Host, s) - if err != nil { - log.Fatalf("goc server fail to run: %v", err) - } - }() + s := store.NewFakeStore() + go func() { + gin.SetMode(gin.ReleaseMode) + err := server.RunGocServerUntilExit(b.Host, s) + if err != nil { + log.Fatalf("goc server fail to run: %v", err) + } + }() - args := []string{"run"} - args = append(args, b.GoArgs...) - cmd := exec.Command("go", args...) - cmd.Dir = b.TmpWd - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + args := []string{"run"} + args = append(args, b.GoArgs...) + cmd := exec.Command("go", args...) + cmd.Dir = b.TmpWd + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr - log.Infof("go run cmd is: %v, in path [%v]", nicePrintArgs(cmd.Args), cmd.Dir) - if err := cmd.Start(); err != nil { - log.Errorf("fail to execute go run: %v", err) - } - if err := cmd.Wait(); err != nil { - log.Errorf("fail to execute go run: %v", err) - } + log.Infof("go run cmd is: %v, in path [%v]", nicePrintArgs(cmd.Args), cmd.Dir) + if err := cmd.Start(); err != nil { + log.Errorf("fail to execute go run: %v", err) + } + if err := cmd.Wait(); err != nil { + log.Errorf("fail to execute go run: %v", err) + } - // done - log.Donef("go run done") + // done + log.Donef("go run done") } diff --git a/pkg/build/tmpfolder.go b/pkg/build/tmpfolder.go index 844469c..1bc2df1 100644 --- a/pkg/build/tmpfolder.go +++ b/pkg/build/tmpfolder.go @@ -14,79 +14,79 @@ package build import ( - "crypto/sha256" - "fmt" - "io/ioutil" - "os" - "path/filepath" - "strings" + "crypto/sha256" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/tongjingran/copy" - "golang.org/x/mod/modfile" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/tongjingran/copy" + "golang.org/x/mod/modfile" ) // copyProjectToTmp copies project files to the temporary directory // // It will ignore .git and irregular files, only copy source(text) files func (b *Build) copyProjectToTmp() { - curProject := b.CurModProjectDir - tmpProject := b.TmpModProjectDir + curProject := b.CurModProjectDir + tmpProject := b.TmpModProjectDir - if _, err := os.Stat(tmpProject); !os.IsNotExist(err) { - log.Infof("find previous temporary directory, delete") - err := os.RemoveAll(tmpProject) - if err != nil { - log.Fatalf("fail to remove preivous temporary directory: %v", err) - } - } + if _, err := os.Stat(tmpProject); !os.IsNotExist(err) { + log.Infof("find previous temporary directory, delete") + err := os.RemoveAll(tmpProject) + if err != nil { + log.Fatalf("fail to remove preivous temporary directory: %v", err) + } + } - log.StartWait("coping project") - err := os.MkdirAll(tmpProject, os.ModePerm) - if err != nil { - log.Fatalf("fail to create temporary directory: %v", err) - } + log.StartWait("coping project") + err := os.MkdirAll(tmpProject, os.ModePerm) + if err != nil { + log.Fatalf("fail to create temporary directory: %v", err) + } - // copy - if err := copy.Copy(curProject, tmpProject, copy.Options{Skip: skipCopy}); err != nil { - log.Fatalf("fail to copy the folder from %v to %v, the err: %v", curProject, tmpProject, err) - } + // copy + if err := copy.Copy(curProject, tmpProject, copy.Options{Skip: skipCopy}); err != nil { + log.Fatalf("fail to copy the folder from %v to %v, the err: %v", curProject, tmpProject, err) + } - log.StopWait() + log.StopWait() } // TmpFolderName generates a directory name according to the path func TmpFolderName(path string) string { - sum := sha256.Sum256([]byte(path)) - h := fmt.Sprintf("%x", sum[:6]) + sum := sha256.Sum256([]byte(path)) + h := fmt.Sprintf("%x", sum[:6]) - return "gocbuild" + h + return "gocbuild" + h } // skipCopy skip copy .git dir and irregular files func skipCopy(src string, info os.FileInfo) (bool, error) { - irregularModeType := os.ModeNamedPipe | os.ModeSocket | os.ModeDevice | os.ModeCharDevice | os.ModeIrregular - if strings.HasSuffix(src, "/.git") { - log.Debugf("skip .git dir [%s]", src) - return true, nil - } - if info.Mode()&irregularModeType != 0 { - log.Debugf("skip file [%s], the file mode is [%s]", src, info.Mode().String()) - return true, nil - } - return false, nil + irregularModeType := os.ModeNamedPipe | os.ModeSocket | os.ModeDevice | os.ModeCharDevice | os.ModeIrregular + if strings.HasSuffix(src, "/.git") { + log.Debugf("skip .git dir [%s]", src) + return true, nil + } + if info.Mode()&irregularModeType != 0 { + log.Debugf("skip file [%s], the file mode is [%s]", src, info.Mode().String()) + return true, nil + } + return false, nil } // clean clears the temporary project func (b *Build) clean() { - if !b.Debug { - if err := os.RemoveAll(b.TmpModProjectDir); err != nil { - log.Fatalf("fail to delete the temporary project: %v", err) - } - log.Donef("delete the temporary project") - } else { - log.Debugf("--debug is enabled, keep the temporary project") - } + if !b.Debug { + if err := os.RemoveAll(b.TmpModProjectDir); err != nil { + log.Fatalf("fail to delete the temporary project: %v", err) + } + log.Donef("delete the temporary project") + } else { + log.Debugf("--debug is enabled, keep the temporary project") + } } // updateGoModFile rewrites the go.mod file in the temporary directory, @@ -102,50 +102,50 @@ func (b *Build) clean() { // after the project is copied to temporary directory, it should be rewritten as // 'replace github.com/qiniu/bar => /path/to/aa/bb/home/foo/bar' func (b *Build) updateGoModFile() (updateFlag bool, newModFile []byte) { - tempModfile := filepath.Join(b.TmpModProjectDir, "go.mod") - buf, err := ioutil.ReadFile(tempModfile) - if err != nil { - log.Fatalf("cannot find go.mod file in temporary directory: %v", err) - } - oriGoModFile, err := modfile.Parse(tempModfile, buf, nil) - if err != nil { - log.Fatalf("cannot parse go.mod: %v", err) - } + tempModfile := filepath.Join(b.TmpModProjectDir, "go.mod") + buf, err := ioutil.ReadFile(tempModfile) + if err != nil { + log.Fatalf("cannot find go.mod file in temporary directory: %v", err) + } + oriGoModFile, err := modfile.Parse(tempModfile, buf, nil) + if err != nil { + log.Fatalf("cannot parse go.mod: %v", err) + } - updateFlag = false - for index := range oriGoModFile.Replace { - replace := oriGoModFile.Replace[index] - oldPath := replace.Old.Path - oldVersion := replace.Old.Version - newPath := replace.New.Path - newVersion := replace.New.Version - // replace to a local filesystem does not have a version - // absolute path no need to rewrite - if newVersion == "" && !filepath.IsAbs(newPath) { - var absPath string - fullPath := filepath.Join(b.CurModProjectDir, newPath) - absPath, _ = filepath.Abs(fullPath) - // DropReplace & AddReplace will not return error - // so no need to check the error - _ = oriGoModFile.DropReplace(oldPath, oldVersion) - _ = oriGoModFile.AddReplace(oldPath, oldVersion, absPath, newVersion) - updateFlag = true - } - } - oriGoModFile.Cleanup() - // Format will not return error, so ignore the returned error - // func (f *File) Format() ([]byte, error) { - // return Format(f.Syntax), nil - // } - newModFile, _ = oriGoModFile.Format() + updateFlag = false + for index := range oriGoModFile.Replace { + replace := oriGoModFile.Replace[index] + oldPath := replace.Old.Path + oldVersion := replace.Old.Version + newPath := replace.New.Path + newVersion := replace.New.Version + // replace to a local filesystem does not have a version + // absolute path no need to rewrite + if newVersion == "" && !filepath.IsAbs(newPath) { + var absPath string + fullPath := filepath.Join(b.CurModProjectDir, newPath) + absPath, _ = filepath.Abs(fullPath) + // DropReplace & AddReplace will not return error + // so no need to check the error + _ = oriGoModFile.DropReplace(oldPath, oldVersion) + _ = oriGoModFile.AddReplace(oldPath, oldVersion, absPath, newVersion) + updateFlag = true + } + } + oriGoModFile.Cleanup() + // Format will not return error, so ignore the returned error + // func (f *File) Format() ([]byte, error) { + // return Format(f.Syntax), nil + // } + newModFile, _ = oriGoModFile.Format() - if updateFlag { - log.Infof("go.mod needs rewrite") - err := os.WriteFile(tempModfile, newModFile, os.ModePerm) - if err != nil { - log.Fatalf("fail to update go.mod: %v", err) - } - b.IsModEdit = true - } - return + if updateFlag { + log.Infof("go.mod needs rewrite") + err := os.WriteFile(tempModfile, newModFile, os.ModePerm) + if err != nil { + log.Fatalf("fail to update go.mod: %v", err) + } + b.IsModEdit = true + } + return } diff --git a/pkg/client/agent.go b/pkg/client/agent.go index 015e1bd..d23bbf7 100644 --- a/pkg/client/agent.go +++ b/pkg/client/agent.go @@ -14,83 +14,83 @@ package client import ( - "encoding/json" - "fmt" - "os" + "encoding/json" + "fmt" + "os" - "github.com/RickLeee/goc/v2/pkg/client/rest" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/olekukonko/tablewriter" + "github.com/ar0c/goc/v2/pkg/client/rest" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/olekukonko/tablewriter" ) const ( - DISCONNECT = 1 << iota - RPCCONNECT = 1 << iota - WATCHCONNECT = 1 << iota + DISCONNECT = 1 << iota + RPCCONNECT = 1 << iota + WATCHCONNECT = 1 << iota ) func ListAgents(host string, ids []string, wide, isJson bool) { - gocClient := rest.NewV2Client(host) + gocClient := rest.NewV2Client(host) - agents, err := gocClient.Agent().Get(ids) + agents, err := gocClient.Agent().Get(ids) - if err != nil { - log.Fatalf("cannot get agent list from goc server: %v", err) - } - table := tablewriter.NewWriter(os.Stdout) - if isJson { - goto asJson - } + if err != nil { + log.Fatalf("cannot get agent list from goc server: %v", err) + } + table := tablewriter.NewWriter(os.Stdout) + if isJson { + goto asJson + } - table.SetCenterSeparator("") - table.SetColumnSeparator("") - table.SetRowSeparator("") - table.SetHeaderLine(false) - table.SetBorder(false) - table.SetTablePadding(" ") // pad with 3 blank spaces - table.SetNoWhiteSpace(true) - table.SetReflowDuringAutoWrap(false) - table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) - table.SetAutoWrapText(false) - if wide { - table.SetHeader([]string{"ID", "STATUS", "REMOTEIP", "HOSTNAME", "PID", "CMD", "EXTRA"}) - table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) - } else { - table.SetHeader([]string{"ID", "STATUS", "REMOTEIP", "CMD"}) - table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) - } + table.SetCenterSeparator("") + table.SetColumnSeparator("") + table.SetRowSeparator("") + table.SetHeaderLine(false) + table.SetBorder(false) + table.SetTablePadding(" ") // pad with 3 blank spaces + table.SetNoWhiteSpace(true) + table.SetReflowDuringAutoWrap(false) + table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) + table.SetAutoWrapText(false) + if wide { + table.SetHeader([]string{"ID", "STATUS", "REMOTEIP", "HOSTNAME", "PID", "CMD", "EXTRA"}) + table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) + } else { + table.SetHeader([]string{"ID", "STATUS", "REMOTEIP", "CMD"}) + table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) + } asJson: - for _, agent := range agents { - var status string - if agent.Status == DISCONNECT { - status = "DISCONNECT" - } else if agent.Status&(RPCCONNECT|WATCHCONNECT) > 0 { - status = "CONNECT" - } - agent.StatusStr = status - if !isJson { - if wide { - table.Append([]string{agent.Id, status, agent.RemoteIP, agent.Hostname, agent.Pid, agent.CmdLine, agent.Extra}) - } else { - preLen := len(agent.Id) + len(agent.RemoteIP) + 9 - table.Append([]string{agent.Id, status, agent.RemoteIP, getSimpleCmdLine(preLen, agent.CmdLine)}) - } - } - } - if !isJson { - table.Render() - } else { - b, _ := json.Marshal(agents) - fmt.Fprint(os.Stdout, string(b)) - } + for _, agent := range agents { + var status string + if agent.Status == DISCONNECT { + status = "DISCONNECT" + } else if agent.Status&(RPCCONNECT|WATCHCONNECT) > 0 { + status = "CONNECT" + } + agent.StatusStr = status + if !isJson { + if wide { + table.Append([]string{agent.Id, status, agent.RemoteIP, agent.Hostname, agent.Pid, agent.CmdLine, agent.Extra}) + } else { + preLen := len(agent.Id) + len(agent.RemoteIP) + 9 + table.Append([]string{agent.Id, status, agent.RemoteIP, getSimpleCmdLine(preLen, agent.CmdLine)}) + } + } + } + if !isJson { + table.Render() + } else { + b, _ := json.Marshal(agents) + fmt.Fprint(os.Stdout, string(b)) + } } func DeleteAgents(host string, ids []string) { - gocClient := rest.NewV2Client(host) + gocClient := rest.NewV2Client(host) - err := gocClient.Agent().Delete(ids) + err := gocClient.Agent().Delete(ids) - if err != nil { - log.Fatalf("cannot delete agents from goc server: %v", err) - } + if err != nil { + log.Fatalf("cannot delete agents from goc server: %v", err) + } } diff --git a/pkg/client/client.go b/pkg/client/client.go index fc00d8b..76326bb 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -14,198 +14,198 @@ package client import ( - "bytes" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "os" - "path/filepath" + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "path/filepath" - "golang.org/x/term" + "golang.org/x/term" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/olekukonko/tablewriter" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/olekukonko/tablewriter" ) // Action provides methods to contact with the covered agent under test type Action interface { - ListAgents(bool) - Profile(string) + ListAgents(bool) + Profile(string) } const ( - // CoverAgentsListAPI list all the registered agents - CoverAgentsListAPI = "/v2/rpcagents" - //CoverProfileAPI is provided by the covered service to get profiles - CoverProfileAPI = "/v2/cover/profile" + // CoverAgentsListAPI list all the registered agents + CoverAgentsListAPI = "/v2/rpcagents" + //CoverProfileAPI is provided by the covered service to get profiles + CoverProfileAPI = "/v2/cover/profile" ) type client struct { - Host string - client *http.Client + Host string + client *http.Client } // gocListAgents response of the list request type gocListAgents struct { - Items []gocCoveredAgent `json:"items"` + Items []gocCoveredAgent `json:"items"` } // gocCoveredAgent represents a covered client type gocCoveredAgent struct { - Id string `json:"id"` - RemoteIP string `json:"remoteip"` - Hostname string `json:"hostname"` - CmdLine string `json:"cmdline"` - Pid string `json:"pid"` + Id string `json:"id"` + RemoteIP string `json:"remoteip"` + Hostname string `json:"hostname"` + CmdLine string `json:"cmdline"` + Pid string `json:"pid"` } type gocProfile struct { - Profile string `json:"profile"` + Profile string `json:"profile"` } // NewWorker creates a worker to contact with host func NewWorker(host string) Action { - _, err := url.ParseRequestURI(host) - if err != nil { - log.Fatalf("parse url %s failed, err: %v", host, err) - } - return &client{ - Host: host, - client: http.DefaultClient, - } + _, err := url.ParseRequestURI(host) + if err != nil { + log.Fatalf("parse url %s failed, err: %v", host, err) + } + return &client{ + Host: host, + client: http.DefaultClient, + } } // ListAgents Deprecated func (c *client) ListAgents(wide bool) { - u := fmt.Sprintf("%s%s", c.Host, CoverAgentsListAPI) - _, body, err := c.do("GET", u, "", nil) - if err != nil && isNetworkError(err) { - _, body, err = c.do("GET", u, "", nil) - } - if err != nil { - log.Fatalf("goc list failed: %v", err) - } - agents := gocListAgents{} - err = json.Unmarshal(body, &agents) - if err != nil { - log.Fatalf("goc list failed: json unmarshal failed: %v", err) - } - table := tablewriter.NewWriter(os.Stdout) - table.SetCenterSeparator("") - table.SetColumnSeparator("") - table.SetRowSeparator("") - table.SetHeaderLine(false) - table.SetBorder(false) - table.SetTablePadding(" ") // pad with 3 blank spaces - table.SetNoWhiteSpace(true) - table.SetReflowDuringAutoWrap(false) - table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) - table.SetAutoWrapText(false) - if wide { - table.SetHeader([]string{"ID", "REMOTEIP", "HOSTNAME", "PID", "CMD"}) - table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) - } else { - table.SetHeader([]string{"ID", "REMOTEIP", "CMD"}) - table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) - } - for _, agent := range agents.Items { - if wide { - table.Append([]string{agent.Id, agent.RemoteIP, agent.Hostname, agent.Pid, agent.CmdLine}) - } else { - preLen := len(agent.Id) + len(agent.RemoteIP) + 9 - table.Append([]string{agent.Id, agent.RemoteIP, getSimpleCmdLine(preLen, agent.CmdLine)}) - } - } - table.Render() - return + u := fmt.Sprintf("%s%s", c.Host, CoverAgentsListAPI) + _, body, err := c.do("GET", u, "", nil) + if err != nil && isNetworkError(err) { + _, body, err = c.do("GET", u, "", nil) + } + if err != nil { + log.Fatalf("goc list failed: %v", err) + } + agents := gocListAgents{} + err = json.Unmarshal(body, &agents) + if err != nil { + log.Fatalf("goc list failed: json unmarshal failed: %v", err) + } + table := tablewriter.NewWriter(os.Stdout) + table.SetCenterSeparator("") + table.SetColumnSeparator("") + table.SetRowSeparator("") + table.SetHeaderLine(false) + table.SetBorder(false) + table.SetTablePadding(" ") // pad with 3 blank spaces + table.SetNoWhiteSpace(true) + table.SetReflowDuringAutoWrap(false) + table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) + table.SetAutoWrapText(false) + if wide { + table.SetHeader([]string{"ID", "REMOTEIP", "HOSTNAME", "PID", "CMD"}) + table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) + } else { + table.SetHeader([]string{"ID", "REMOTEIP", "CMD"}) + table.SetColumnAlignment([]int{tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT, tablewriter.ALIGN_LEFT}) + } + for _, agent := range agents.Items { + if wide { + table.Append([]string{agent.Id, agent.RemoteIP, agent.Hostname, agent.Pid, agent.CmdLine}) + } else { + preLen := len(agent.Id) + len(agent.RemoteIP) + 9 + table.Append([]string{agent.Id, agent.RemoteIP, getSimpleCmdLine(preLen, agent.CmdLine)}) + } + } + table.Render() + return } func (c *client) Profile(output string) { - u := fmt.Sprintf("%s%s", c.Host, CoverProfileAPI) + u := fmt.Sprintf("%s%s", c.Host, CoverProfileAPI) - res, profile, err := c.do("GET", u, "application/json", nil) - if err != nil && isNetworkError(err) { - res, profile, err = c.do("GET", u, "application/json", nil) - } + res, profile, err := c.do("GET", u, "application/json", nil) + if err != nil && isNetworkError(err) { + res, profile, err = c.do("GET", u, "application/json", nil) + } - if err == nil && res.StatusCode != 200 { - log.Fatalf(string(profile)) - } - var profileText gocProfile - err = json.Unmarshal(profile, &profileText) - if err != nil { - log.Fatalf("profile unmarshal failed: %v", err) - } - if output == "" { - fmt.Fprint(os.Stdout, profileText.Profile) - } else { - var dir, filename string = filepath.Split(output) - if dir != "" { - err = os.MkdirAll(dir, os.ModePerm) - if err != nil { - log.Fatalf("failed to create directory %s, err:%v", dir, err) - } - } - if filename == "" { - output += "coverage.cov" - } + if err == nil && res.StatusCode != 200 { + log.Fatalf(string(profile)) + } + var profileText gocProfile + err = json.Unmarshal(profile, &profileText) + if err != nil { + log.Fatalf("profile unmarshal failed: %v", err) + } + if output == "" { + fmt.Fprint(os.Stdout, profileText.Profile) + } else { + var dir, filename string = filepath.Split(output) + if dir != "" { + err = os.MkdirAll(dir, os.ModePerm) + if err != nil { + log.Fatalf("failed to create directory %s, err:%v", dir, err) + } + } + if filename == "" { + output += "coverage.cov" + } - f, err := os.Create(output) - if err != nil { - log.Fatalf("failed to create file %s, err:%v", output, err) - } - defer f.Close() - _, err = io.Copy(f, bytes.NewReader([]byte(profileText.Profile))) - if err != nil { - log.Fatalf("failed to write file: %v, err: %v", output, err) - } - } + f, err := os.Create(output) + if err != nil { + log.Fatalf("failed to create file %s, err:%v", output, err) + } + defer f.Close() + _, err = io.Copy(f, bytes.NewReader([]byte(profileText.Profile))) + if err != nil { + log.Fatalf("failed to write file: %v, err: %v", output, err) + } + } } // getSimpleCmdLine func getSimpleCmdLine(preLen int, cmdLine string) string { - pathLen := len(cmdLine) - width, _, err := term.GetSize(int(os.Stdin.Fd())) - if err != nil || width <= preLen+16 { - width = 16 + preLen // show at least 16 words of the command - } - if pathLen > width-preLen { - return cmdLine[:width-preLen] - } - return cmdLine + pathLen := len(cmdLine) + width, _, err := term.GetSize(int(os.Stdin.Fd())) + if err != nil || width <= preLen+16 { + width = 16 + preLen // show at least 16 words of the command + } + if pathLen > width-preLen { + return cmdLine[:width-preLen] + } + return cmdLine } func (c *client) do(method, url, contentType string, body io.Reader) (*http.Response, []byte, error) { - req, err := http.NewRequest(method, url, body) - if err != nil { - return nil, nil, err - } + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, nil, err + } - if contentType != "" { - req.Header.Set("Content-Type", contentType) - } + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } - res, err := c.client.Do(req) - if err != nil { - return nil, nil, err - } - defer res.Body.Close() + res, err := c.client.Do(req) + if err != nil { + return nil, nil, err + } + defer res.Body.Close() - responseBody, err := ioutil.ReadAll(res.Body) - if err != nil { - return res, nil, err - } - return res, responseBody, nil + responseBody, err := ioutil.ReadAll(res.Body) + if err != nil { + return res, nil, err + } + return res, responseBody, nil } func isNetworkError(err error) bool { - if err == io.EOF { - return true - } - _, ok := err.(net.Error) - return ok + if err == io.EOF { + return true + } + _, ok := err.(net.Error) + return ok } diff --git a/pkg/client/profie.go b/pkg/client/profie.go index 443a943..1234587 100644 --- a/pkg/client/profie.go +++ b/pkg/client/profie.go @@ -14,61 +14,61 @@ package client import ( - "bytes" - "fmt" - "io" - "os" - "path/filepath" + "bytes" + "fmt" + "io" + "os" + "path/filepath" - "github.com/RickLeee/goc/v2/pkg/client/rest" - "github.com/RickLeee/goc/v2/pkg/client/rest/profile" - "github.com/RickLeee/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/client/rest" + "github.com/ar0c/goc/v2/pkg/client/rest/profile" + "github.com/ar0c/goc/v2/pkg/log" ) func GetProfile(host string, ids []string, skips []string, extra string, output string, need []string) { - gocClient := rest.NewV2Client(host) + gocClient := rest.NewV2Client(host) - profiles, err := gocClient.Profile().Get(ids, - profile.WithPackagePattern(skips), - profile.WithExtraPattern(extra), - profile.WithNeed(need)) - if err != nil { - log.Fatalf("fail to get profile from the goc server: %v, response: %v", err, profiles) - } + profiles, err := gocClient.Profile().Get(ids, + profile.WithPackagePattern(skips), + profile.WithExtraPattern(extra), + profile.WithNeed(need)) + if err != nil { + log.Fatalf("fail to get profile from the goc server: %v, response: %v", err, profiles) + } - if output == "" { - fmt.Fprint(os.Stdout, profiles) - } else { - var dir, filename string = filepath.Split(output) - if dir != "" { - err = os.MkdirAll(dir, os.ModePerm) - if err != nil { - log.Fatalf("failed to create directory %s, err:%v", dir, err) - } - } - if filename == "" { - output += "coverage.cov" - } + if output == "" { + fmt.Fprint(os.Stdout, profiles) + } else { + var dir, filename string = filepath.Split(output) + if dir != "" { + err = os.MkdirAll(dir, os.ModePerm) + if err != nil { + log.Fatalf("failed to create directory %s, err:%v", dir, err) + } + } + if filename == "" { + output += "coverage.cov" + } - f, err := os.Create(output) - if err != nil { - log.Fatalf("failed to create file %s, err:%v", output, err) - } - defer f.Close() - _, err = io.Copy(f, bytes.NewReader([]byte(profiles))) - if err != nil { - log.Fatalf("failed to write file: %v, err: %v", output, err) - } - } + f, err := os.Create(output) + if err != nil { + log.Fatalf("failed to create file %s, err:%v", output, err) + } + defer f.Close() + _, err = io.Copy(f, bytes.NewReader([]byte(profiles))) + if err != nil { + log.Fatalf("failed to write file: %v, err: %v", output, err) + } + } } func ClearProfile(host string, ids []string, extra string) { - gocClient := rest.NewV2Client(host) + gocClient := rest.NewV2Client(host) - err := gocClient.Profile().Delete(ids, - profile.WithExtraPattern(extra)) + err := gocClient.Profile().Delete(ids, + profile.WithExtraPattern(extra)) - if err != nil { - log.Fatalf("fail to clear the profile: %v", err) - } + if err != nil { + log.Fatalf("fail to clear the profile: %v", err) + } } diff --git a/pkg/client/rest/client.go b/pkg/client/rest/client.go index 339d2bb..38223a9 100644 --- a/pkg/client/rest/client.go +++ b/pkg/client/rest/client.go @@ -14,26 +14,26 @@ package rest import ( - "github.com/RickLeee/goc/v2/pkg/client/rest/agent" - "github.com/RickLeee/goc/v2/pkg/client/rest/profile" - "github.com/go-resty/resty/v2" + "github.com/ar0c/goc/v2/pkg/client/rest/agent" + "github.com/ar0c/goc/v2/pkg/client/rest/profile" + "github.com/go-resty/resty/v2" ) // V2Client provides methods contact with the covered agent under test type V2Client struct { - rest *resty.Client + rest *resty.Client } func NewV2Client(host string) *V2Client { - return &V2Client{ - rest: resty.New().SetHostURL("http://" + host), - } + return &V2Client{ + rest: resty.New().SetHostURL("http://" + host), + } } func (c *V2Client) Agent() agent.AgentInterface { - return agent.NewAgentsClient(c.rest) + return agent.NewAgentsClient(c.rest) } func (c *V2Client) Profile() profile.ProfileInterface { - return profile.NewProfileClient(c.rest) + return profile.NewProfileClient(c.rest) } diff --git a/pkg/server/api.go b/pkg/server/api.go index 754fd6a..84a446d 100644 --- a/pkg/server/api.go +++ b/pkg/server/api.go @@ -14,390 +14,390 @@ package server import ( - "bytes" - "fmt" - "net/http" - "regexp" - "strings" - "sync" - "time" + "bytes" + "fmt" + "net/http" + "regexp" + "strings" + "sync" + "time" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/gin-gonic/gin" - "golang.org/x/tools/cover" - "k8s.io/test-infra/gopherage/pkg/cov" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/gin-gonic/gin" + "golang.org/x/tools/cover" + "k8s.io/test-infra/gopherage/pkg/cov" ) // listAgents return all service informations func (gs *gocServer) listAgents(c *gin.Context) { - idQuery := c.Query("id") - ifInIdMap := idMaps(idQuery) + idQuery := c.Query("id") + ifInIdMap := idMaps(idQuery) - agents := make([]*gocCoveredAgent, 0) + agents := make([]*gocCoveredAgent, 0) - gs.agents.Range(func(key, value interface{}) bool { - // check if id is in the query ids - if !ifInIdMap(key.(string)) { - return true - } + gs.agents.Range(func(key, value interface{}) bool { + // check if id is in the query ids + if !ifInIdMap(key.(string)) { + return true + } - agent, ok := value.(*gocCoveredAgent) - if !ok { - return false - } - agents = append(agents, agent) - return true - }) + agent, ok := value.(*gocCoveredAgent) + if !ok { + return false + } + agents = append(agents, agent) + return true + }) - c.JSON(http.StatusOK, gin.H{ - "items": agents, - }) + c.JSON(http.StatusOK, gin.H{ + "items": agents, + }) } func (gs *gocServer) removeAgents(c *gin.Context) { - idQuery := c.Query("id") - ifInIdMap := idMaps(idQuery) + idQuery := c.Query("id") + ifInIdMap := idMaps(idQuery) - errs := "" - gs.agents.Range(func(key, value interface{}) bool { + errs := "" + gs.agents.Range(func(key, value interface{}) bool { - // check if id is in the query ids - id := key.(string) - if !ifInIdMap(id) { - return true - } + // check if id is in the query ids + id := key.(string) + if !ifInIdMap(id) { + return true + } - agent, ok := value.(*gocCoveredAgent) - if !ok { - return false - } + agent, ok := value.(*gocCoveredAgent) + if !ok { + return false + } - err := gs.removeAgentFromStore(id) - if err != nil { - log.Errorf("fail to remove agent: %v", id) - err := fmt.Errorf("fail to remove agent: %v, err: %v", id, err) - errs = errs + err.Error() - return true - } - agent.closeConnection() - gs.agents.Delete(key) + err := gs.removeAgentFromStore(id) + if err != nil { + log.Errorf("fail to remove agent: %v", id) + err := fmt.Errorf("fail to remove agent: %v, err: %v", id, err) + errs = errs + err.Error() + return true + } + agent.closeConnection() + gs.agents.Delete(key) - return true - }) + return true + }) - if errs != "" { - c.JSON(http.StatusInternalServerError, gin.H{ - "msg": errs, - }) - } else { - c.JSON(http.StatusOK, nil) - } + if errs != "" { + c.JSON(http.StatusInternalServerError, gin.H{ + "msg": errs, + }) + } else { + c.JSON(http.StatusOK, nil) + } } // getProfiles get and merge all agents' informations // // it is synchronous func (gs *gocServer) getProfiles(c *gin.Context) { - idQuery := c.Query("id") - ifInIdMap := idMaps(idQuery) + idQuery := c.Query("id") + ifInIdMap := idMaps(idQuery) - skippatternRaw := c.Query("skippattern") - var skippattern []string - if skippatternRaw != "" { - skippattern = strings.Split(skippatternRaw, ",") - } - neerpatternRaw := c.Query("needpattern") - var neerpattern []string - if neerpatternRaw != "" { - neerpattern = strings.Split(neerpatternRaw, ",") - } + skippatternRaw := c.Query("skippattern") + var skippattern []string + if skippatternRaw != "" { + skippattern = strings.Split(skippatternRaw, ",") + } + neerpatternRaw := c.Query("needpattern") + var neerpattern []string + if neerpatternRaw != "" { + neerpattern = strings.Split(neerpatternRaw, ",") + } - extra := c.Query("extra") - isExtra := filterExtra(extra) + extra := c.Query("extra") + isExtra := filterExtra(extra) - var mu sync.Mutex - var wg sync.WaitGroup + var mu sync.Mutex + var wg sync.WaitGroup - mergedProfiles := make([][]*cover.Profile, 0) + mergedProfiles := make([][]*cover.Profile, 0) - gs.agents.Range(func(key, value interface{}) bool { - // check if id is in the query ids - if !ifInIdMap(key.(string)) { - // not in - return true - } + gs.agents.Range(func(key, value interface{}) bool { + // check if id is in the query ids + if !ifInIdMap(key.(string)) { + // not in + return true + } - agent, ok := value.(*gocCoveredAgent) - if !ok { - return false - } + agent, ok := value.(*gocCoveredAgent) + if !ok { + return false + } - // check if extra matches - if !isExtra(agent.Extra) { - // not match - return true - } + // check if extra matches + if !isExtra(agent.Extra) { + // not match + return true + } - wg.Add(1) - // 并发 rpc,且每个 rpc 设超时时间 10 second - go func() { - defer wg.Done() + wg.Add(1) + // 并发 rpc,且每个 rpc 设超时时间 10 second + go func() { + defer wg.Done() - timeout := time.Duration(10 * time.Second) - done := make(chan error, 1) + timeout := time.Duration(10 * time.Second) + done := make(chan error, 1) - var req ProfileReq = "getprofile" - var res ProfileRes - go func() { - // lock-free - rpc := agent.rpc - if rpc == nil || agent.Status == DISCONNECT { - done <- nil - return - } - err := agent.rpc.Call("GocAgent.GetProfile", req, &res) - if err != nil { - log.Errorf("fail to get profile from: %v, reasson: %v. let's close the connection", agent.Id, err) - } - done <- err - }() + var req ProfileReq = "getprofile" + var res ProfileRes + go func() { + // lock-free + rpc := agent.rpc + if rpc == nil || agent.Status == DISCONNECT { + done <- nil + return + } + err := agent.rpc.Call("GocAgent.GetProfile", req, &res) + if err != nil { + log.Errorf("fail to get profile from: %v, reasson: %v. let's close the connection", agent.Id, err) + } + done <- err + }() - select { - // rpc 超时 - case <-time.After(timeout): - log.Warnf("rpc call timeout: %v", agent.Hostname) - // 关闭链接 - agent.closeRpcConnOnce() - case err := <-done: - // 调用 rpc 发生错误 - if err != nil { - // 关闭链接 - agent.closeRpcConnOnce() - } - } - // append profile - profile, err := convertProfile([]byte(res)) - if err != nil { - log.Errorf("fail to convert the received profile from: %v, reasson: %v. let's close the connection", agent.Id, err) - // 关闭链接 - agent.closeRpcConnOnce() - return - } + select { + // rpc 超时 + case <-time.After(timeout): + log.Warnf("rpc call timeout: %v", agent.Hostname) + // 关闭链接 + agent.closeRpcConnOnce() + case err := <-done: + // 调用 rpc 发生错误 + if err != nil { + // 关闭链接 + agent.closeRpcConnOnce() + } + } + // append profile + profile, err := convertProfile([]byte(res)) + if err != nil { + log.Errorf("fail to convert the received profile from: %v, reasson: %v. let's close the connection", agent.Id, err) + // 关闭链接 + agent.closeRpcConnOnce() + return + } - // check if skippattern matches - newProfile := filterProfileByPattern(skippattern, neerpattern, profile) + // check if skippattern matches + newProfile := filterProfileByPattern(skippattern, neerpattern, profile) - mu.Lock() - defer mu.Unlock() - mergedProfiles = append(mergedProfiles, newProfile) - }() + mu.Lock() + defer mu.Unlock() + mergedProfiles = append(mergedProfiles, newProfile) + }() - return true - }) + return true + }) - // 一直等待并发的 rpc 都回应 - wg.Wait() + // 一直等待并发的 rpc 都回应 + wg.Wait() - merged, err := cov.MergeMultipleProfiles(mergedProfiles) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "msg": err.Error(), - }) - return - } + merged, err := cov.MergeMultipleProfiles(mergedProfiles) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "msg": err.Error(), + }) + return + } - var buff bytes.Buffer - err = cov.DumpProfile(merged, &buff) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{ - "msg": err.Error(), - }) - return - } + var buff bytes.Buffer + err = cov.DumpProfile(merged, &buff) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "msg": err.Error(), + }) + return + } - c.JSON(http.StatusOK, gin.H{ - "profile": buff.String(), - }) + c.JSON(http.StatusOK, gin.H{ + "profile": buff.String(), + }) } // resetProfiles reset all profiles in agent // // it is async, the function will return immediately func (gs *gocServer) resetProfiles(c *gin.Context) { - idQuery := c.Query("id") - ifInIdMap := idMaps(idQuery) + idQuery := c.Query("id") + ifInIdMap := idMaps(idQuery) - extra := c.Query("extra") - isExtra := filterExtra(extra) + extra := c.Query("extra") + isExtra := filterExtra(extra) - gs.agents.Range(func(key, value interface{}) bool { + gs.agents.Range(func(key, value interface{}) bool { - // check if id is in the query ids - if !ifInIdMap(key.(string)) { - // not in - return true - } + // check if id is in the query ids + if !ifInIdMap(key.(string)) { + // not in + return true + } - agent, ok := value.(*gocCoveredAgent) - if !ok { - return false - } + agent, ok := value.(*gocCoveredAgent) + if !ok { + return false + } - // check if extra matches - if !isExtra(agent.Extra) { - // not match - return true - } + // check if extra matches + if !isExtra(agent.Extra) { + // not match + return true + } - var req ProfileReq = "resetprofile" - var res ProfileRes - go func() { - // lock-free - rpc := agent.rpc - if rpc == nil || agent.Status == DISCONNECT { - return - } - err := rpc.Call("GocAgent.ResetProfile", req, &res) - if err != nil { - log.Errorf("fail to reset profile from: %v, reasson: %v. let's close the connection", agent.Id, err) - // 关闭链接 - agent.closeRpcConnOnce() - } - }() + var req ProfileReq = "resetprofile" + var res ProfileRes + go func() { + // lock-free + rpc := agent.rpc + if rpc == nil || agent.Status == DISCONNECT { + return + } + err := rpc.Call("GocAgent.ResetProfile", req, &res) + if err != nil { + log.Errorf("fail to reset profile from: %v, reasson: %v. let's close the connection", agent.Id, err) + // 关闭链接 + agent.closeRpcConnOnce() + } + }() - return true - }) + return true + }) } // watchProfileUpdate watch the profile change // // any profile change will be updated on this websocket connection. func (gs *gocServer) watchProfileUpdate(c *gin.Context) { - // upgrade to websocket - ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - log.Errorf("fail to establish websocket connection with watch client: %v", err) - c.JSON(http.StatusInternalServerError, nil) - } + // upgrade to websocket + ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Errorf("fail to establish websocket connection with watch client: %v", err) + c.JSON(http.StatusInternalServerError, nil) + } - log.Infof("watch client connected") + log.Infof("watch client connected") - id := time.Now().String() - gwc := &gocWatchClient{ - ws: ws, - exitCh: make(chan int), - } - gs.watchClients.Store(id, gwc) - // send close msg and close ws connection - defer func() { - gs.watchClients.Delete(id) - ws.Close() - gwc.once.Do(func() { close(gwc.exitCh) }) - log.Infof("watch client disconnected") - }() + id := time.Now().String() + gwc := &gocWatchClient{ + ws: ws, + exitCh: make(chan int), + } + gs.watchClients.Store(id, gwc) + // send close msg and close ws connection + defer func() { + gs.watchClients.Delete(id) + ws.Close() + gwc.once.Do(func() { close(gwc.exitCh) }) + log.Infof("watch client disconnected") + }() - // set pong handler - ws.SetReadDeadline(time.Now().Add(PongWait)) - ws.SetPongHandler(func(string) error { - ws.SetReadDeadline(time.Now().Add(PongWait)) - return nil - }) + // set pong handler + ws.SetReadDeadline(time.Now().Add(PongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(PongWait)) + return nil + }) - // set ping goroutine to ping every PingWait time - go func() { - ticker := time.NewTicker(PingWait) - defer ticker.Stop() + // set ping goroutine to ping every PingWait time + go func() { + ticker := time.NewTicker(PingWait) + defer ticker.Stop() - for range ticker.C { - if err := gs.wsping(ws, PongWait); err != nil { - break - } - } + for range ticker.C { + if err := gs.wsping(ws, PongWait); err != nil { + break + } + } - gwc.once.Do(func() { close(gwc.exitCh) }) - }() + gwc.once.Do(func() { close(gwc.exitCh) }) + }() - <-gwc.exitCh + <-gwc.exitCh } func filterProfileByPattern(skippattern []string, needpattern []string, profiles []*cover.Profile) []*cover.Profile { - var out = make([]*cover.Profile, 0) - var skipOut = make([]*cover.Profile, 0) - if len(skippattern) == 0 && len(needpattern) == 0 { - return profiles - } - if len(skippattern) != 0 { - for _, profile := range profiles { - skip := false - for _, pattern := range skippattern { - if strings.Contains(profile.FileName, pattern) { - skip = true - break - } - } + var out = make([]*cover.Profile, 0) + var skipOut = make([]*cover.Profile, 0) + if len(skippattern) == 0 && len(needpattern) == 0 { + return profiles + } + if len(skippattern) != 0 { + for _, profile := range profiles { + skip := false + for _, pattern := range skippattern { + if strings.Contains(profile.FileName, pattern) { + skip = true + break + } + } - if !skip { - skipOut = append(skipOut, profile) - } - } - } else { - skipOut = profiles - } - log.Infof("skipOut len: %v", len(skipOut)) - if len(needpattern) == 0 { - return skipOut - } + if !skip { + skipOut = append(skipOut, profile) + } + } + } else { + skipOut = profiles + } + log.Infof("skipOut len: %v", len(skipOut)) + if len(needpattern) == 0 { + return skipOut + } - for _, profile := range skipOut { - need := false - for _, pattern := range needpattern { - if strings.Contains(profile.FileName, pattern) { - need = true - break - } - } - if need { - out = append(out, profile) - } - } - log.Infof("need out len: %v", len(out)) + for _, profile := range skipOut { + need := false + for _, pattern := range needpattern { + if strings.Contains(profile.FileName, pattern) { + need = true + break + } + } + if need { + out = append(out, profile) + } + } + log.Infof("need out len: %v", len(out)) - return out + return out } func idMaps(idQuery string) func(key string) bool { - idMap := make(map[string]bool) - if len(strings.TrimSpace(idQuery)) == 0 { - } else { - ids := strings.Split(idQuery, ",") - for _, id := range ids { - idMap[id] = true - } - } + idMap := make(map[string]bool) + if len(strings.TrimSpace(idQuery)) == 0 { + } else { + ids := strings.Split(idQuery, ",") + for _, id := range ids { + idMap[id] = true + } + } - inIdMaps := func(key string) bool { - // if no id in query, then all id agent will be return - if len(idMap) == 0 { - return true - } - // other - _, ok := idMap[key] - if !ok { - return false - } else { - return true - } - } + inIdMaps := func(key string) bool { + // if no id in query, then all id agent will be return + if len(idMap) == 0 { + return true + } + // other + _, ok := idMap[key] + if !ok { + return false + } else { + return true + } + } - return inIdMaps + return inIdMaps } func filterExtra(extraPattern string) func(string) bool { - re := regexp.MustCompile(extraPattern) + re := regexp.MustCompile(extraPattern) - return func(extra string) bool { - return re.Match([]byte(extra)) - } + return func(extra string) bool { + return re.Match([]byte(extra)) + } } diff --git a/pkg/server/rpcstream.go b/pkg/server/rpcstream.go index 0f35189..667303a 100644 --- a/pkg/server/rpcstream.go +++ b/pkg/server/rpcstream.go @@ -14,16 +14,16 @@ package server import ( - "crypto/sha256" - "fmt" - "net/http" - "net/rpc" - "net/rpc/jsonrpc" - "sync" - "time" + "crypto/sha256" + "fmt" + "net/http" + "net/rpc" + "net/rpc/jsonrpc" + "sync" + "time" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/gin-gonic/gin" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/gin-gonic/gin" ) // serveRpcStream holds connection between goc server and agent. @@ -32,114 +32,114 @@ import ( // // 2. 每个链接的 goc agent 作为 rpc 服务端 func (gs *gocServer) serveRpcStream(c *gin.Context) { - // 检查插桩服务上报的信息 - rpcRemoteIP, _ := c.RemoteIP() - id := c.Query("id") - token := c.Query("token") + // 检查插桩服务上报的信息 + rpcRemoteIP, _ := c.RemoteIP() + id := c.Query("id") + token := c.Query("token") - rawagent, ok := gs.agents.Load(id) - if !ok { - c.JSON(http.StatusBadRequest, gin.H{ - "msg": "agent not registered", - "code": 1, - }) - return - } + rawagent, ok := gs.agents.Load(id) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "agent not registered", + "code": 1, + }) + return + } - agent := rawagent.(*gocCoveredAgent) - if agent.Token != token { - c.JSON(http.StatusBadRequest, gin.H{ - "msg": "register token not match", - "code": 1, - }) - return - } + agent := rawagent.(*gocCoveredAgent) + if agent.Token != token { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "register token not match", + "code": 1, + }) + return + } - // 更新 agent 信息 - agent.RpcRemoteIP = rpcRemoteIP.String() - agent.exitCh = make(chan int) - agent.Status &= ^DISCONNECT // 取消 DISCONNECT 的状态 - agent.Status |= RPCCONNECT // 设置为 RPC CONNECT 状态 - // 注册销毁函数 - var once sync.Once - agent.closeRpcConnOnce = func() { - once.Do(func() { - // 为什么只是关闭 channel?其它资源如何释放? - // close channel 后,本 goroutine 会进入到 defer - close(agent.exitCh) - }) - } + // 更新 agent 信息 + agent.RpcRemoteIP = rpcRemoteIP.String() + agent.exitCh = make(chan int) + agent.Status &= ^DISCONNECT // 取消 DISCONNECT 的状态 + agent.Status |= RPCCONNECT // 设置为 RPC CONNECT 状态 + // 注册销毁函数 + var once sync.Once + agent.closeRpcConnOnce = func() { + once.Do(func() { + // 为什么只是关闭 channel?其它资源如何释放? + // close channel 后,本 goroutine 会进入到 defer + close(agent.exitCh) + }) + } - // upgrade to websocket - ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - log.Errorf("fail to establish websocket connection with rpc agent: %v", err) - c.JSON(http.StatusInternalServerError, nil) - } + // upgrade to websocket + ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Errorf("fail to establish websocket connection with rpc agent: %v", err) + c.JSON(http.StatusInternalServerError, nil) + } - // send close msg and close ws connection - defer func() { - deadline := 1 * time.Second - // 发送 close msg - gs.wsclose(ws, deadline) - time.Sleep(deadline) + // send close msg and close ws connection + defer func() { + deadline := 1 * time.Second + // 发送 close msg + gs.wsclose(ws, deadline) + time.Sleep(deadline) - // 取消 RPC CONNECT 状态 - agent.Status &= ^RPCCONNECT - if agent.Status == 0 { - agent.Status = DISCONNECT - } + // 取消 RPC CONNECT 状态 + agent.Status &= ^RPCCONNECT + if agent.Status == 0 { + agent.Status = DISCONNECT + } - ws.Close() - log.Infof("close rpc connection, %v", agent.Hostname) - // reset rpc client - agent.rpc = nil - }() + ws.Close() + log.Infof("close rpc connection, %v", agent.Hostname) + // reset rpc client + agent.rpc = nil + }() - // set pong handler - ws.SetReadDeadline(time.Now().Add(PongWait)) - ws.SetPongHandler(func(string) error { - ws.SetReadDeadline(time.Now().Add(PongWait)) - return nil - }) + // set pong handler + ws.SetReadDeadline(time.Now().Add(PongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(PongWait)) + return nil + }) - // set ping goroutine to ping every PingWait time - go func() { - ticker := time.NewTicker(PingWait) - defer ticker.Stop() + // set ping goroutine to ping every PingWait time + go func() { + ticker := time.NewTicker(PingWait) + defer ticker.Stop() - for range ticker.C { - if err := gs.wsping(ws, PongWait); err != nil { - log.Errorf("rpc ping to %v failed: %v", agent.Hostname, err) - break - } - } + for range ticker.C { + if err := gs.wsping(ws, PongWait); err != nil { + log.Errorf("rpc ping to %v failed: %v", agent.Hostname, err) + break + } + } - agent.closeRpcConnOnce() - }() + agent.closeRpcConnOnce() + }() - log.Infof("one rpc agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), agent.CmdLine, agent.Pid, agent.Hostname) - // new rpc agent - // 在这里 websocket server 作为 rpc 的客户端, - // 发送 rpc 请求, - // 由被插桩服务返回 rpc 应答 - rwc := &ReadWriteCloser{ws: ws} - codec := jsonrpc.NewClientCodec(rwc) + log.Infof("one rpc agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), agent.CmdLine, agent.Pid, agent.Hostname) + // new rpc agent + // 在这里 websocket server 作为 rpc 的客户端, + // 发送 rpc 请求, + // 由被插桩服务返回 rpc 应答 + rwc := &ReadWriteCloser{ws: ws} + codec := jsonrpc.NewClientCodec(rwc) - agent.rpc = rpc.NewClientWithCodec(codec) + agent.rpc = rpc.NewClientWithCodec(codec) - // wait for exit - <-agent.exitCh + // wait for exit + <-agent.exitCh } // generateAgentId generate id based on agent's meta infomation func (gs *gocServer) generateAgentId(args ...string) gocCliendId { - var path string - for _, arg := range args { - path += arg - } - sum := sha256.Sum256([]byte(path)) - h := fmt.Sprintf("%x", sum[:6]) + var path string + for _, arg := range args { + path += arg + } + sum := sha256.Sum256([]byte(path)) + h := fmt.Sprintf("%x", sum[:6]) - return gocCliendId(h) + return gocCliendId(h) } diff --git a/pkg/server/server.go b/pkg/server/server.go index 07450a4..96aa9cb 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -14,236 +14,236 @@ package server import ( - "crypto/sha256" - "encoding/json" - "fmt" - "math/rand" - "net/http" - "net/rpc" - "strconv" - "sync" - "sync/atomic" - "time" + "crypto/sha256" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "net/rpc" + "strconv" + "sync" + "sync/atomic" + "time" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/RickLeee/goc/v2/pkg/server/store" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/ar0c/goc/v2/pkg/server/store" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) // gocServer represents a goc server type gocServer struct { - port int - store store.Store + port int + store store.Store - upgrader websocket.Upgrader + upgrader websocket.Upgrader - agents sync.Map + agents sync.Map - watchCh chan []byte - watchClients sync.Map + watchCh chan []byte + watchClients sync.Map - idCount int64 - idL sync.Mutex + idCount int64 + idL sync.Mutex } type gocCliendId string const ( - DISCONNECT = 1 << iota - RPCCONNECT = 1 << iota - WATCHCONNECT = 1 << iota + DISCONNECT = 1 << iota + RPCCONNECT = 1 << iota + WATCHCONNECT = 1 << iota ) // gocCoveredAgent represents a covered client type gocCoveredAgent struct { - Id string `json:"id"` - RpcRemoteIP string `json:"rpc_remoteip"` - WatchRemoteIP string `json:"watch_remoteip"` - Hostname string `json:"hostname"` - CmdLine string `json:"cmdline"` - Pid string `json:"pid"` + Id string `json:"id"` + RpcRemoteIP string `json:"rpc_remoteip"` + WatchRemoteIP string `json:"watch_remoteip"` + Hostname string `json:"hostname"` + CmdLine string `json:"cmdline"` + Pid string `json:"pid"` - // 用户可以选择上报一些定制信息 - // 比如不同 namespace 的 statefulset POD,它们的 hostname/cmdline/pid 都是一样的, - // 这时候将 extra 设置为 namespace 并上报,这个额外的信息在展示时将更友好 - Extra string `json:"extra"` + // 用户可以选择上报一些定制信息 + // 比如不同 namespace 的 statefulset POD,它们的 hostname/cmdline/pid 都是一样的, + // 这时候将 extra 设置为 namespace 并上报,这个额外的信息在展示时将更友好 + Extra string `json:"extra"` - Token string `json:"token"` - Status int `json:"status"` // 表示该 agent 是否处于 connected 状态 + Token string `json:"token"` + Status int `json:"status"` // 表示该 agent 是否处于 connected 状态 - rpc *rpc.Client `json:"-"` + rpc *rpc.Client `json:"-"` - exitCh chan int `json:"-"` - closeRpcConnOnce func() `json:"-"` // close rpc conn 只执行一次 - closeWatchConnOnce func() `json:"-"` // close watch conn 只执行一次 + exitCh chan int `json:"-"` + closeRpcConnOnce func() `json:"-"` // close rpc conn 只执行一次 + closeWatchConnOnce func() `json:"-"` // close watch conn 只执行一次 } func (agent *gocCoveredAgent) closeConnection() { - if agent.closeRpcConnOnce != nil { - agent.closeRpcConnOnce() - } + if agent.closeRpcConnOnce != nil { + agent.closeRpcConnOnce() + } - if agent.closeWatchConnOnce != nil { - agent.closeWatchConnOnce() - } + if agent.closeWatchConnOnce != nil { + agent.closeWatchConnOnce() + } } // api 客户端,不是 agent type gocWatchClient struct { - ws *websocket.Conn - exitCh chan int - once sync.Once + ws *websocket.Conn + exitCh chan int + once sync.Once } func RunGocServerUntilExit(host string, s store.Store) error { - gs := gocServer{ - store: s, - upgrader: websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - HandshakeTimeout: 45 * time.Second, - CheckOrigin: func(r *http.Request) bool { - return true - }, - }, - watchCh: make(chan []byte, 4096), - } + gs := gocServer{ + store: s, + upgrader: websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + HandshakeTimeout: 45 * time.Second, + CheckOrigin: func(r *http.Request) bool { + return true + }, + }, + watchCh: make(chan []byte, 4096), + } - // 从持久化存储上恢复 agent 列表 - gs.restoreAgents() + // 从持久化存储上恢复 agent 列表 + gs.restoreAgents() - r := gin.Default() - v2 := r.Group("/v2") - { - v2.GET("/cover/profile", gs.getProfiles) - v2.DELETE("/cover/profile", gs.resetProfiles) - v2.GET("/agents", gs.listAgents) - v2.DELETE("/agents", gs.removeAgents) + r := gin.Default() + v2 := r.Group("/v2") + { + v2.GET("/cover/profile", gs.getProfiles) + v2.DELETE("/cover/profile", gs.resetProfiles) + v2.GET("/agents", gs.listAgents) + v2.DELETE("/agents", gs.removeAgents) - v2.GET("/cover/ws/watch", gs.watchProfileUpdate) + v2.GET("/cover/ws/watch", gs.watchProfileUpdate) - // internal use only - v2.GET("/internal/register", gs.register) - v2.GET("/internal/ws/rpcstream", gs.serveRpcStream) - v2.GET("/internal/ws/watchstream", gs.serveWatchInternalStream) - } + // internal use only + v2.GET("/internal/register", gs.register) + v2.GET("/internal/ws/rpcstream", gs.serveRpcStream) + v2.GET("/internal/ws/watchstream", gs.serveWatchInternalStream) + } - go gs.watchLoop() - return r.Run(host) + go gs.watchLoop() + return r.Run(host) } func (gs *gocServer) register(c *gin.Context) { - // 检查插桩服务上报的信息 - hostname := c.Query("hostname") - pid := c.Query("pid") - cmdline := c.Query("cmdline") - extra := c.Query("extra") + // 检查插桩服务上报的信息 + hostname := c.Query("hostname") + pid := c.Query("pid") + cmdline := c.Query("cmdline") + extra := c.Query("extra") - if hostname == "" || pid == "" || cmdline == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "msg": "missing some params", - }) - return - } + if hostname == "" || pid == "" || cmdline == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "missing some params", + }) + return + } - gs.idL.Lock() - gs.idCount++ - globalId := gs.idCount - gs.idL.Unlock() + gs.idL.Lock() + gs.idCount++ + globalId := gs.idCount + gs.idL.Unlock() - genToken := func(i int64) string { - now := time.Now().UnixNano() - random := rand.Int() + genToken := func(i int64) string { + now := time.Now().UnixNano() + random := rand.Int() - raw := fmt.Sprintf("%v-%v-%v", i, random, now) - sum := sha256.Sum256([]byte(raw)) - h := fmt.Sprintf("%x", sum[:16]) + raw := fmt.Sprintf("%v-%v-%v", i, random, now) + sum := sha256.Sum256([]byte(raw)) + h := fmt.Sprintf("%x", sum[:16]) - return h - } + return h + } - token := genToken(globalId) - id := strconv.Itoa(int(globalId)) + token := genToken(globalId) + id := strconv.Itoa(int(globalId)) - agent := &gocCoveredAgent{ - Id: id, - Hostname: hostname, - Pid: pid, - CmdLine: cmdline, - Token: token, - Status: DISCONNECT, - Extra: extra, - } + agent := &gocCoveredAgent{ + Id: id, + Hostname: hostname, + Pid: pid, + CmdLine: cmdline, + Token: token, + Status: DISCONNECT, + Extra: extra, + } - // 持久化 - err := gs.saveAgentToStore(agent) - if err != nil { - log.Errorf("fail to save to store: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "msg": err.Error(), - }) - } - // 维护 agent 连接 - gs.agents.Store(id, agent) + // 持久化 + err := gs.saveAgentToStore(agent) + if err != nil { + log.Errorf("fail to save to store: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "msg": err.Error(), + }) + } + // 维护 agent 连接 + gs.agents.Store(id, agent) - log.Infof("one agent registered, id: %v, cmdline: %v, pid: %v, hostname: %v", id, agent.CmdLine, agent.Pid, agent.Hostname) + log.Infof("one agent registered, id: %v, cmdline: %v, pid: %v, hostname: %v", id, agent.CmdLine, agent.Pid, agent.Hostname) - c.JSON(http.StatusOK, gin.H{ - "id": id, - "token": token, - }) + c.JSON(http.StatusOK, gin.H{ + "id": id, + "token": token, + }) } func (gs *gocServer) saveAgentToStore(agent *gocCoveredAgent) error { - value, err := json.Marshal(agent) - if err != nil { - return err - } - return gs.store.Set("/goc/agents/"+agent.Id, string(value)) + value, err := json.Marshal(agent) + if err != nil { + return err + } + return gs.store.Set("/goc/agents/"+agent.Id, string(value)) } func (gs *gocServer) removeAgentFromStore(id string) error { - return gs.store.Remove("/goc/agents/" + id) + return gs.store.Remove("/goc/agents/" + id) } func (gs *gocServer) removeAllAgentsFromStore() error { - return gs.store.RangeRemove("/goc/agents/") + return gs.store.RangeRemove("/goc/agents/") } func (gs *gocServer) restoreAgents() { - pattern := "/goc/agents/" + pattern := "/goc/agents/" - // ignore err, 这个 err 不需要处理,直接忽略 - rawagents, _ := gs.store.Range(pattern) + // ignore err, 这个 err 不需要处理,直接忽略 + rawagents, _ := gs.store.Range(pattern) - var maxId int - for _, rawagent := range rawagents { - var agent gocCoveredAgent - err := json.Unmarshal([]byte(rawagent), &agent) - if err != nil { - log.Fatalf("fail to unmarshal restore agents: %v", err) - } + var maxId int + for _, rawagent := range rawagents { + var agent gocCoveredAgent + err := json.Unmarshal([]byte(rawagent), &agent) + if err != nil { + log.Fatalf("fail to unmarshal restore agents: %v", err) + } - id, err := strconv.Atoi(agent.Id) - if err != nil { - log.Fatalf("fail to transform id to number: %v", err) - } - if maxId < id { - maxId = id - } + id, err := strconv.Atoi(agent.Id) + if err != nil { + log.Fatalf("fail to transform id to number: %v", err) + } + if maxId < id { + maxId = id + } - gs.agents.Store(agent.Id, &agent) - log.Infof("restore one agent: %v, %v from store", id, agent.RpcRemoteIP) + gs.agents.Store(agent.Id, &agent) + log.Infof("restore one agent: %v, %v from store", id, agent.RpcRemoteIP) - agent.RpcRemoteIP = "" - agent.WatchRemoteIP = "" - agent.Status = DISCONNECT - } + agent.RpcRemoteIP = "" + agent.WatchRemoteIP = "" + agent.Status = DISCONNECT + } - // 更新全局 id - atomic.StoreInt64(&gs.idCount, int64(maxId)) + // 更新全局 id + atomic.StoreInt64(&gs.idCount, int64(maxId)) } diff --git a/pkg/server/watchstream.go b/pkg/server/watchstream.go index d7a7346..a18ffd8 100644 --- a/pkg/server/watchstream.go +++ b/pkg/server/watchstream.go @@ -14,120 +14,120 @@ package server import ( - "net/http" - "sync" - "time" + "net/http" + "sync" + "time" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" ) func (gs *gocServer) serveWatchInternalStream(c *gin.Context) { - // 检查插桩服务上报的信息 - watchRemoteIP, _ := c.RemoteIP() - id := c.Query("id") - token := c.Query("token") + // 检查插桩服务上报的信息 + watchRemoteIP, _ := c.RemoteIP() + id := c.Query("id") + token := c.Query("token") - rawagent, ok := gs.agents.Load(id) - if !ok { - c.JSON(http.StatusBadRequest, gin.H{ - "msg": "agent not registered", - "code": 1, - }) - return - } + rawagent, ok := gs.agents.Load(id) + if !ok { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "agent not registered", + "code": 1, + }) + return + } - agent := rawagent.(*gocCoveredAgent) - if agent.Token != token { - c.JSON(http.StatusBadRequest, gin.H{ - "msg": "register token not match", - "code": 1, - }) - return - } + agent := rawagent.(*gocCoveredAgent) + if agent.Token != token { + c.JSON(http.StatusBadRequest, gin.H{ + "msg": "register token not match", + "code": 1, + }) + return + } - // 更新 agent 信息 - agent.WatchRemoteIP = watchRemoteIP.String() - agent.Status &= ^DISCONNECT // 取消 DISCONNECT 的状态 - agent.Status |= WATCHCONNECT // 设置为 RPC CONNECT 状态 - var once sync.Once + // 更新 agent 信息 + agent.WatchRemoteIP = watchRemoteIP.String() + agent.Status &= ^DISCONNECT // 取消 DISCONNECT 的状态 + agent.Status |= WATCHCONNECT // 设置为 RPC CONNECT 状态 + var once sync.Once - // upgrade to websocket - ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) - if err != nil { - log.Errorf("fail to establish websocket connection with watch agent: %v", err) - c.JSON(http.StatusInternalServerError, nil) - } + // upgrade to websocket + ws, err := gs.upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Errorf("fail to establish websocket connection with watch agent: %v", err) + c.JSON(http.StatusInternalServerError, nil) + } - // 注册销毁函数 - agent.closeWatchConnOnce = func() { - once.Do(func() { - // 关闭 ws 连接后,ws.ReadMessage() 会出错退出 goroutine,进入 defer - ws.Close() - }) - } + // 注册销毁函数 + agent.closeWatchConnOnce = func() { + once.Do(func() { + // 关闭 ws 连接后,ws.ReadMessage() 会出错退出 goroutine,进入 defer + ws.Close() + }) + } - // send close msg and close ws connection - defer func() { - // 取消 WATCH CONNECT 状态 - agent.Status &= ^WATCHCONNECT - if agent.Status == 0 { - agent.Status = DISCONNECT - } + // send close msg and close ws connection + defer func() { + // 取消 WATCH CONNECT 状态 + agent.Status &= ^WATCHCONNECT + if agent.Status == 0 { + agent.Status = DISCONNECT + } - agent.closeWatchConnOnce() + agent.closeWatchConnOnce() - log.Infof("close watch connection, %v", agent.Hostname) - }() + log.Infof("close watch connection, %v", agent.Hostname) + }() - // set pong handler - ws.SetReadDeadline(time.Now().Add(PongWait)) - ws.SetPongHandler(func(string) error { - ws.SetReadDeadline(time.Now().Add(PongWait)) - return nil - }) + // set pong handler + ws.SetReadDeadline(time.Now().Add(PongWait)) + ws.SetPongHandler(func(string) error { + ws.SetReadDeadline(time.Now().Add(PongWait)) + return nil + }) - // set ping goroutine to ping every PingWait time - go func() { - ticker := time.NewTicker(PingWait) - defer ticker.Stop() + // set ping goroutine to ping every PingWait time + go func() { + ticker := time.NewTicker(PingWait) + defer ticker.Stop() - for range ticker.C { - if err := gs.wsping(ws, PongWait); err != nil { - log.Errorf("watch ping to %v failed: %v", agent.Hostname, err) - break - } - } - }() + for range ticker.C { + if err := gs.wsping(ws, PongWait); err != nil { + log.Errorf("watch ping to %v failed: %v", agent.Hostname, err) + break + } + } + }() - log.Infof("one watch agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), agent.CmdLine, agent.Pid, agent.Hostname) + log.Infof("one watch agent established, %v, cmdline: %v, pid: %v, hostname: %v", ws.RemoteAddr(), agent.CmdLine, agent.Pid, agent.Hostname) - for { - mt, message, err := ws.ReadMessage() - if err != nil { - log.Errorf("read from %v: %v", agent.Hostname, err) - break - } - if mt == websocket.TextMessage { - gs.watchCh <- message - } - } + for { + mt, message, err := ws.ReadMessage() + if err != nil { + log.Errorf("read from %v: %v", agent.Hostname, err) + break + } + if mt == websocket.TextMessage { + gs.watchCh <- message + } + } } func (gs *gocServer) watchLoop() { - for { - msg := <-gs.watchCh - gs.watchClients.Range(func(key, value interface{}) bool { - // 这里是客户端的 ws 连接,不是 agent ws 连接 - gwc := value.(*gocWatchClient) - err := gwc.ws.WriteMessage(websocket.TextMessage, msg) - if err != nil { - gwc.ws.Close() - gwc.once.Do(func() { close(gwc.exitCh) }) - } + for { + msg := <-gs.watchCh + gs.watchClients.Range(func(key, value interface{}) bool { + // 这里是客户端的 ws 连接,不是 agent ws 连接 + gwc := value.(*gocWatchClient) + err := gwc.ws.WriteMessage(websocket.TextMessage, msg) + if err != nil { + gwc.ws.Close() + gwc.once.Do(func() { close(gwc.exitCh) }) + } - return true - }) - } + return true + }) + } } diff --git a/pkg/watch/watch.go b/pkg/watch/watch.go index 1c5b914..9917caf 100644 --- a/pkg/watch/watch.go +++ b/pkg/watch/watch.go @@ -14,26 +14,26 @@ package watch import ( - "fmt" + "fmt" - "github.com/RickLeee/goc/v2/pkg/log" - "github.com/gorilla/websocket" + "github.com/ar0c/goc/v2/pkg/log" + "github.com/gorilla/websocket" ) func Watch(host string) { - watchUrl := fmt.Sprintf("ws://%v/v2/cover/ws/watch", host) - c, _, err := websocket.DefaultDialer.Dial(watchUrl, nil) - if err != nil { - log.Fatalf("cannot connect to goc server: %v", err) - } - defer c.Close() + watchUrl := fmt.Sprintf("ws://%v/v2/cover/ws/watch", host) + c, _, err := websocket.DefaultDialer.Dial(watchUrl, nil) + if err != nil { + log.Fatalf("cannot connect to goc server: %v", err) + } + defer c.Close() - for { - _, message, err := c.ReadMessage() - if err != nil { - log.Fatalf("cannot read message: %v", err) - } + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Fatalf("cannot read message: %v", err) + } - log.Infof("profile update: %v", string(message)) - } + log.Infof("profile update: %v", string(message)) + } }