Created
September 8, 2022 13:28
-
-
Save jbfarez/16014f89207201da3fdff4105e9036e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package simple-updater | |
import ( | |
"fmt" | |
"io" | |
"io/ioutil" | |
"os" | |
"runtime" | |
"strconv" | |
"strings" | |
"syscall" | |
"time" | |
"github.com/aws/aws-sdk-go/aws" | |
"github.com/aws/aws-sdk-go/aws/credentials" | |
"github.com/aws/aws-sdk-go/aws/session" | |
"github.com/aws/aws-sdk-go/service/s3" | |
"github.com/mitchellh/ioprogress" | |
) | |
type Updater struct { | |
CurrentVersion string | |
S3Bucket string | |
S3Region string | |
S3ReleaseKey string | |
S3VersionKey string | |
AWSCredentials *credentials.Credentials | |
} | |
// validate ensures every required fields is correctly set. Otherwise and error is returned. | |
func (u Updater) validate() error { | |
if u.CurrentVersion == "" { | |
return fmt.Errorf("no version set") | |
} | |
if u.S3Bucket == "" { | |
return fmt.Errorf("no bucket set") | |
} | |
if u.S3Region == "" { | |
return fmt.Errorf("no s3 region") | |
} | |
if u.S3ReleaseKey == "" { | |
return fmt.Errorf("no s3ReleaseKey set") | |
} | |
if u.S3VersionKey == "" { | |
return fmt.Errorf("no s3VersionKey set") | |
} | |
return nil | |
} | |
// AutoUpdate runs synchronously a verification to ensure the binary is up-to-date. | |
// If a new version gets released, the download will happen automatically | |
// It's possible to bypass this mechanism by setting the SIMPLE_UPDATER_DISABLED environment variable. | |
func AutoUpdate(u Updater) error { | |
if os.Getenv("SIMPLE_UPDATER_DISABLED") != "" { | |
fmt.Println("simple-updater: autoupdate disabled") | |
return nil | |
} | |
if err := u.validate(); err != nil { | |
fmt.Printf("simple-updater: %s - skipping auto update\n", err.Error()) | |
return err | |
} | |
return runAutoUpdate(u) | |
} | |
// generateS3ReleaseKey turns a configured release key template into the key
// for the current platform. Every "{{OS}}" becomes runtime.GOOS and every
// "{{ARCH}}" becomes runtime.GOARCH. Placeholders are case-sensitive and all
// other text is returned unchanged.
func generateS3ReleaseKey(path string) string {
	platform := strings.NewReplacer(
		"{{OS}}", runtime.GOOS,
		"{{ARCH}}", runtime.GOARCH,
	)
	return platform.Replace(path)
}
func runAutoUpdate(u Updater) error { | |
localVersion, err := strconv.ParseInt(u.CurrentVersion, 10, 64) | |
if err != nil || localVersion == 0 { | |
return fmt.Errorf("invalid local version") | |
} | |
svc := s3.New(session.New(), &aws.Config{ | |
Region: aws.String(u.S3Region), | |
Credentials: u.AWSCredentials, | |
}) | |
resp, err := svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(u.S3Bucket), Key: aws.String(u.S3VersionKey)}) | |
if err != nil { | |
return err | |
} | |
defer resp.Body.Close() | |
b, err := ioutil.ReadAll(resp.Body) | |
if err != nil { | |
return err | |
} | |
remoteVersion, err := strconv.ParseInt(string(b), 10, 64) | |
if err != nil || remoteVersion == 0 { | |
return fmt.Errorf("invalid remote version") | |
} | |
fmt.Printf("simple-updater: Local Version %d - Remote Version: %d\n", localVersion, remoteVersion) | |
if localVersion < remoteVersion { | |
fmt.Printf("simple-updater: version outdated ... \n") | |
s3Key := generateS3ReleaseKey(u.S3ReleaseKey) | |
resp, err := svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(u.S3Bucket), Key: aws.String(s3Key)}) | |
if err != nil { | |
return err | |
} | |
defer resp.Body.Close() | |
progressR := &ioprogress.Reader{ | |
Reader: resp.Body, | |
Size: *resp.ContentLength, | |
DrawInterval: 500 * time.Millisecond, | |
DrawFunc: ioprogress.DrawTerminalf(os.Stdout, func(progress, total int64) string { | |
bar := ioprogress.DrawTextFormatBar(40) | |
return fmt.Sprintf("%s %20s", bar(progress, total), ioprogress.DrawTextFormatBytes(progress, total)) | |
}), | |
} | |
dest, err := os.Executable() | |
if err != nil { | |
return err | |
} | |
// Move the old version to a backup path that we can recover from | |
// in case the upgrade fails | |
destBackup := dest + ".bak" | |
if _, err := os.Stat(dest); err == nil { | |
os.Rename(dest, destBackup) | |
} | |
// Use the same flags that ioutil.WriteFile uses | |
f, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755) | |
if err != nil { | |
os.Rename(destBackup, dest) | |
return err | |
} | |
defer f.Close() | |
fmt.Printf("simple-updater: downloading new version to %s\n", dest) | |
if _, err := io.Copy(f, progressR); err != nil { | |
os.Rename(destBackup, dest) | |
return err | |
} | |
// The file must be closed already so we can execute it in the next step | |
f.Close() | |
// Removing backup | |
os.Remove(destBackup) | |
fmt.Printf("simple-updater: updated with success to version %d\nRestarting application\n", remoteVersion) | |
// The update completed, we can now restart the application without requiring any user action. | |
if err := syscall.Exec(dest, os.Args, os.Environ()); err != nil { | |
return err | |
} | |
os.Exit(0) | |
} | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment