90
pkg/common/cryptoutil/tls.go
Normal file
90
pkg/common/cryptoutil/tls.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package cryptoutil
|
||||
|
||||
import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
"math/big"
|
||||
)
|
||||
|
||||
func InsecureTLSConfig() (*tls.Config, error) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
template := x509.Certificate{SerialNumber: big.NewInt(1)}
|
||||
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey, privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: bytes})
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||
|
||||
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
CurvePreferences: []tls.CurveID{tls.X25519},
|
||||
CipherSuites: []uint16{tls.TLS_CHACHA20_POLY1305_SHA256},
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
InsecureSkipVerify: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ClientTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
CurvePreferences: []tls.CurveID{tls.X25519},
|
||||
CipherSuites: []uint16{tls.TLS_CHACHA20_POLY1305_SHA256},
|
||||
Certificates: []tls.Certificate{cert},
|
||||
RootCAs: caCertPool,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ServerTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
|
||||
caCert, err := ioutil.ReadFile(caPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
CurvePreferences: []tls.CurveID{tls.X25519},
|
||||
CipherSuites: []uint16{tls.TLS_CHACHA20_POLY1305_SHA256},
|
||||
Certificates: []tls.Certificate{cert},
|
||||
ClientCAs: caCertPool,
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
}, nil
|
||||
}
|
||||
28
pkg/portal/command.go
Normal file
28
pkg/portal/command.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package portal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/apex/log"
|
||||
|
||||
"github.com/speedrunsh/speedrun/proto/portal"
|
||||
)
|
||||
|
||||
func (s *Server) RunCommand(ctx context.Context, in *portal.CommandRequest) (*portal.CommandResponse, error) {
|
||||
fields := log.Fields{
|
||||
"context": "command",
|
||||
}
|
||||
log := log.WithFields(fields)
|
||||
|
||||
log.Debugf("Received command: %s %s", in.GetName(), in.GetArgs())
|
||||
cmd := exec.Command(in.GetName(), in.GetArgs()...)
|
||||
stdout, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return &portal.CommandResponse{Message: strings.TrimSpace(string(stdout))}, nil
|
||||
}
|
||||
7
pkg/portal/server.go
Normal file
7
pkg/portal/server.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package portal
|
||||
|
||||
import "github.com/speedrunsh/speedrun/proto/portal"
|
||||
|
||||
type Server struct {
|
||||
portal.DRPCPortalUnimplementedServer
|
||||
}
|
||||
158
pkg/portal/service.go
Normal file
158
pkg/portal/service.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package portal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/apex/log"
|
||||
"github.com/coreos/go-systemd/v22/dbus"
|
||||
"github.com/speedrunsh/speedrun/proto/portal"
|
||||
)
|
||||
|
||||
func (s *Server) ServiceRestart(ctx context.Context, service *portal.ServiceRequest) (*portal.ServiceResponse, error) {
|
||||
fields := log.Fields{
|
||||
"context": "service",
|
||||
"command": "restart",
|
||||
"name": service.GetName(),
|
||||
}
|
||||
log := log.WithFields(fields)
|
||||
log.Debug("Received service restart request")
|
||||
|
||||
conn, err := dbus.NewWithContext(ctx)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
responseChan := make(chan string, 1)
|
||||
serviceName := fmt.Sprintf("%s.service", service.GetName())
|
||||
_, err = conn.RestartUnitContext(ctx, serviceName, "replace", responseChan)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := <-responseChan
|
||||
log.Debugf("Service restart result: %v", res)
|
||||
return &portal.ServiceResponse{State: portal.State_CHANGED, Message: strings.Title(res)}, nil
|
||||
}
|
||||
|
||||
func (s *Server) ServiceStop(ctx context.Context, service *portal.ServiceRequest) (*portal.ServiceResponse, error) {
|
||||
fields := log.Fields{
|
||||
"context": "service",
|
||||
"command": "stop",
|
||||
"name": service.GetName(),
|
||||
}
|
||||
log := log.WithFields(fields)
|
||||
log.Debug("Received service stop request")
|
||||
|
||||
conn, err := dbus.NewWithContext(ctx)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
responseChan := make(chan string, 1)
|
||||
serviceName := fmt.Sprintf("%s.service", service.GetName())
|
||||
list, err := conn.ListUnitsByNamesContext(ctx, []string{serviceName})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Fetched service list by name: %v", list)
|
||||
if list[0].ActiveState == "inactive" {
|
||||
return &portal.ServiceResponse{State: portal.State_UNCHANGED, Message: "Service already stopped"}, nil
|
||||
}
|
||||
|
||||
_, err = conn.StopUnitContext(ctx, serviceName, "replace", responseChan)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := <-responseChan
|
||||
log.Debugf("Service stop result: %v", res)
|
||||
return &portal.ServiceResponse{State: portal.State_CHANGED, Message: strings.Title(res)}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) ServiceStart(ctx context.Context, service *portal.ServiceRequest) (*portal.ServiceResponse, error) {
|
||||
fields := log.Fields{
|
||||
"context": "service",
|
||||
"command": "start",
|
||||
"name": service.GetName(),
|
||||
}
|
||||
log := log.WithFields(fields)
|
||||
log.Debug("Received service start request")
|
||||
|
||||
conn, err := dbus.NewWithContext(ctx)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
responseChan := make(chan string, 1)
|
||||
serviceName := fmt.Sprintf("%s.service", service.GetName())
|
||||
list, err := conn.ListUnitsByNamesContext(ctx, []string{serviceName})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Fetched service list by name: %v", list)
|
||||
if list[0].ActiveState == "active" {
|
||||
return &portal.ServiceResponse{State: portal.State_UNCHANGED, Message: "Service already running"}, nil
|
||||
}
|
||||
|
||||
_, err = conn.StartUnitContext(ctx, serviceName, "replace", responseChan)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := <-responseChan
|
||||
log.Debugf("Service start result: %v", res)
|
||||
return &portal.ServiceResponse{State: portal.State_CHANGED, Message: strings.Title(res)}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) ServiceStatus(ctx context.Context, service *portal.ServiceRequest) (*portal.ServiceStatusResponse, error) {
|
||||
fields := log.Fields{
|
||||
"context": "service",
|
||||
"command": "status",
|
||||
"name": service.GetName(),
|
||||
}
|
||||
log := log.WithFields(fields)
|
||||
log.Debug("Received service status request")
|
||||
|
||||
conn, err := dbus.NewWithContext(ctx)
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
serviceName := fmt.Sprintf("%s.service", service.GetName())
|
||||
res, err := conn.ListUnitsByNamesContext(ctx, []string{serviceName})
|
||||
if err != nil {
|
||||
log.Error(err.Error())
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Fetched service list by name: %v", res[0])
|
||||
|
||||
if res[0].LoadState == "not-found" {
|
||||
log.Error("service not found")
|
||||
return nil, fmt.Errorf("service not found")
|
||||
}
|
||||
|
||||
return &portal.ServiceStatusResponse{
|
||||
State: portal.State_UNCHANGED,
|
||||
Activestate: res[0].ActiveState,
|
||||
Loadstate: res[0].LoadState,
|
||||
Substate: res[0].SubState,
|
||||
}, nil
|
||||
|
||||
}
|
||||
53
pkg/speedrun/cloud/google.go
Normal file
53
pkg/speedrun/cloud/google.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/api/compute/v1"
|
||||
)
|
||||
|
||||
type GoogleClient struct {
|
||||
*compute.Service
|
||||
}
|
||||
|
||||
func NewGCPClient() (*GoogleClient, error) {
|
||||
var err error
|
||||
ctx := context.Background()
|
||||
|
||||
gce, err := compute.NewService(ctx)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("couldn't initialize GCP client: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GoogleClient{gce}, nil
|
||||
}
|
||||
|
||||
// GetInstances returns a list of external IP addresses used for the SHH connection
|
||||
func (c *GoogleClient) GetInstances(project string) ([]Instance, error) {
|
||||
instances := []Instance{}
|
||||
listCall := c.Instances.AggregatedList(project).Fields("nextPageToken", "items(Name,NetworkInterfaces,Labels)")
|
||||
var ctx context.Context
|
||||
|
||||
listCall.Pages(ctx, func(list *compute.InstanceAggregatedList) error {
|
||||
for _, item := range list.Items {
|
||||
for _, instance := range item.Instances {
|
||||
i := Instance{
|
||||
Name: instance.Name,
|
||||
PrivateAddress: instance.NetworkInterfaces[0].NetworkIP,
|
||||
PublicAddress: instance.NetworkInterfaces[0].AccessConfigs[0].NatIP,
|
||||
Labels: instance.Labels,
|
||||
}
|
||||
instances = append(instances, i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
_, err := listCall.Do()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
93
pkg/speedrun/cloud/instance.go
Normal file
93
pkg/speedrun/cloud/instance.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package cloud
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/antonmedv/expr"
|
||||
"github.com/apex/log"
|
||||
"github.com/speedrunsh/speedrun/pkg/common/cryptoutil"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
PublicAddress string
|
||||
PrivateAddress string
|
||||
Name string
|
||||
Labels map[string]string
|
||||
}
|
||||
|
||||
func (i Instance) GetAddress(private bool) string {
|
||||
if private {
|
||||
return i.PrivateAddress
|
||||
}
|
||||
|
||||
return i.PublicAddress
|
||||
}
|
||||
|
||||
func GetInstances(target string) ([]Instance, error) {
|
||||
project := viper.GetString("gcp.projectid")
|
||||
|
||||
gcpClient, err := NewGCPClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Info("Fetching instance list")
|
||||
instances, err := gcpClient.GetInstances(project)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
subset, err := filter(instances, target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(subset) == 0 {
|
||||
return nil, fmt.Errorf("no instances found")
|
||||
}
|
||||
|
||||
return subset, nil
|
||||
}
|
||||
|
||||
func SetupTLS() (*tls.Config, error) {
|
||||
insecure := viper.GetBool("tls.insecure")
|
||||
caPath := viper.GetString("tls.ca")
|
||||
certPath := viper.GetString("tls.cert")
|
||||
keyPath := viper.GetString("tls.key")
|
||||
|
||||
if insecure {
|
||||
log.Warn("Using insecure TLS configuration, this should be avoided in production environments")
|
||||
return cryptoutil.InsecureTLSConfig()
|
||||
} else {
|
||||
return cryptoutil.ClientTLSConfig(caPath, certPath, keyPath)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func filter(instnces []Instance, target string) ([]Instance, error) {
|
||||
var subset []Instance
|
||||
|
||||
program, err := expr.Compile(target, expr.Env(Instance{}), expr.AsBool())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, instance := range instnces {
|
||||
output, err := expr.Run(program, instance)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
match, ok := output.(bool)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if match {
|
||||
subset = append(subset, instance)
|
||||
}
|
||||
}
|
||||
return subset, nil
|
||||
}
|
||||
Reference in New Issue
Block a user