package main import ( "context" "crypto/tls" "fmt" "io" "net/http" "os" "regexp" "runtime" "sort" "time" "github.com/Masterminds/semver/v3" "github.com/apex/log" "github.com/fynelabs/selfupdate" ) var reVersion = regexp.MustCompile(`(?m)`) // Pulled from main.go type Flags struct { // Wrap global flags and configurations to add sub-command logic. *models.Flags Login CommandLogin `command:"login" description:"login to example"` Update CommandUpdate `command:"update" description:"update example-cli"` Get struct { // TRUNCATED } `command:"get" description:"get information about a resource"` } type CommandUpdate struct { NexusURL string `long:"nexus-url" description:"URL to use for nexus" default:"https://nexus.example.com"` NexusRepository string `long:"nexus-repository" description:"Nexus repository to use" default:"some-repository-name"` TLSSkipVerify bool `long:"tls-skip-verify" description:"Skip TLS verification"` AllowPrerelease bool `long:"allow-prerelease" description:"Allow updating to prerelease versions"` client *http.Client } func (c *CommandUpdate) Execute(args []string) error { c.initClient() newVersion := c.needsUpdate(context.Background(), true) if newVersion == nil { logger.Info("example-cli is up to date") os.Exit(0) } binary := fmt.Sprintf("example_%s_%s", runtime.GOOS, runtime.GOARCH) if runtime.GOOS == "windows" { binary += ".exe" } uri := fmt.Sprintf( "%s/repository/%s/%s/%s/%s/%s", c.NexusURL, c.NexusRepository, productName, applicationName, newVersion.Original(), binary, ) logger.WithField("uri", uri).Info("downloading update") resp, err := c.client.Get(uri) if err != nil { logger.WithError(err).Fatal("error downloading update") } err = selfupdate.Apply(resp.Body, selfupdate.Options{}) resp.Body.Close() if err != nil { logger.WithError(err).Error("error applying update") if rerr := selfupdate.RollbackError(err); rerr != nil { logger.WithError(rerr).Error("error rolling back update") } os.Exit(1) } logger.WithField("new-version", newVersion.String()).Info("example-cli updated successfully") return nil } func (c *CommandUpdate) initClient() { if c.client != nil { return } c.client = &http.Client{ Timeout: time.Second * 10, Transport: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: c.TLSSkipVerify, }, }, } } // needsUpdate checks if an update is available. If required is set, and an update // check can't be done, this will exit the program with an error. If newVersion is // nil, no update is available. func (c *CommandUpdate) needsUpdate(ctx context.Context, required bool) *semver.Version { c.initClient() logger.WithField("current-version", version).Debug("checking for updates") // Fetch versions. current, versions, err := c.fetchVersions(ctx) if err != nil { if !required { logger.WithError(err).Debug("error fetching versions") return nil } logger.WithError(err).Fatal("error fetching versions") } // Check if we're up to date. if !current.LessThan(versions[0]) { return nil } return versions[0] } func (c *CommandUpdate) fetchVersions(ctx context.Context) (current *semver.Version, versions []*semver.Version, err error) { current, err = semver.NewVersion(version) if err != nil { return nil, nil, fmt.Errorf("error parsing current version %q: %w", version, err) } uri := fmt.Sprintf( "%s/service/rest/repository/browse/%s/%s/%s/", c.NexusURL, c.NexusRepository, productName, applicationName, ) logger.WithField("uri", uri).Debug("fetching versions") req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, http.NoBody) if err != nil { return current, nil, fmt.Errorf("error creating version check request: %w", err) } req.Header.Set("Accept", "text/html,application/xhtml+xml,application/xml") req.Header.Set("Cache-Control", "no-cache") req.Header.Set("User-Agent", fmt.Sprintf("%s/%s/%s", productName, applicationName, version)) resp, err := c.client.Do(req) if err != nil { return current, nil, fmt.Errorf("error performing version check request: %w", err) } defer resp.Body.Close() data, err := io.ReadAll(resp.Body) if err != nil { return current, nil, fmt.Errorf("error parsing version response: %w", err) } logger.WithFields(log.Fields{ "status": resp.Status, "body": string(data), }).Debug("version check response") if resp.StatusCode != http.StatusOK { return current, nil, fmt.Errorf("error performing version check request: %s", resp.Status) } for _, result := range reVersion.FindAllStringSubmatch(string(data), -1) { v, err := semver.NewVersion(result[1]) if err != nil { logger.WithError(err).WithField("version", result[1]).Warn("error parsing version, skipping") continue } if !c.AllowPrerelease && v.Prerelease() != "" { logger.WithField("version", result[1]).Debug("skipping prerelease version") continue } logger.WithField("version", v.String()).Debug("found version") versions = append(versions, v) } if len(versions) == 0 { return current, nil, fmt.Errorf("no versions found") } sort.Sort(sort.Reverse(semver.Collection(versions))) return current, versions, nil }