The following pull request was submitted through GitHub. It can be accessed and reviewed at: https://github.com/lxc/lxd/pull/6240
This e-mail was sent by the LXC bot; direct replies will not reach the author unless they happen to be subscribed to this list. === Description (from pull-request) === This shows the basic structure of the LXD VM agent. The API looks similar to the one in LXD itself, which should make it easier to modify once common code is moved to importable packages (i.e. out of package main). Both `operations.go` and `response.go` are (almost) identical to the ones in the LXD main package. Once #6237 is merged, we should be able to get rid of those duplicates. This closes #6234.
From ae05812e07b4b3117affd4e162039c83f8058726 Mon Sep 17 00:00:00 2001 From: Thomas Hipp <thomas.h...@canonical.com> Date: Wed, 25 Sep 2019 16:19:42 +0200 Subject: [PATCH 1/3] lxd: Move IsJSONRequest to util package This moves IsJSONRequest to the util package in order for it to be used by the lxd-agent. Signed-off-by: Thomas Hipp <thomas.h...@canonical.com> --- lxd/daemon.go | 13 +------------ lxd/util/http.go | 11 +++++++++++ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/lxd/daemon.go b/lxd/daemon.go index 2d2d433ee5..4ba13980a3 100644 --- a/lxd/daemon.go +++ b/lxd/daemon.go @@ -318,17 +318,6 @@ func writeMacaroonsRequiredResponse(b *identchecker.Bakery, r *http.Request, w h return } -func isJSONRequest(r *http.Request) bool { - for k, vs := range r.Header { - if strings.ToLower(k) == "content-type" && - len(vs) == 1 && strings.ToLower(vs[0]) == "application/json" { - return true - } - } - - return false -} - // State creates a new State instance linked to our internal db and os. func (d *Daemon) State() *state.State { return state.NewState(d.db, d.cluster, d.maas, d.os, d.endpoints) @@ -403,7 +392,7 @@ func (d *Daemon) createCmd(restAPI *mux.Router, version string, c APIEndpoint) { } // Dump full request JSON when in debug mode - if debug && r.Method != "GET" && isJSONRequest(r) { + if debug && r.Method != "GET" && util.IsJSONRequest(r) { newBody := &bytes.Buffer{} captured := &bytes.Buffer{} multiW := io.MultiWriter(newBody, captured) diff --git a/lxd/util/http.go b/lxd/util/http.go index 96674dd60a..6e65b6779b 100644 --- a/lxd/util/http.go +++ b/lxd/util/http.go @@ -269,3 +269,14 @@ func GetListeners(start int) []net.Listener { // stdout and stderr), so this constant should always be the value passed to // GetListeners, except for unit tests. 
const SystemdListenFDsStart = 3 + +func IsJSONRequest(r *http.Request) bool { + for k, vs := range r.Header { + if strings.ToLower(k) == "content-type" && + len(vs) == 1 && strings.ToLower(vs[0]) == "application/json" { + return true + } + } + + return false +} From caed22079fcf9e67648dcc2b0691b7e149d68854 Mon Sep 17 00:00:00 2001 From: Thomas Hipp <thomas.h...@canonical.com> Date: Wed, 25 Sep 2019 11:24:12 +0200 Subject: [PATCH 2/3] client: Add vsock support This allows us to reuse client functions when communicating with the vm agent inside of the VM. Signed-off-by: Thomas Hipp <thomas.h...@canonical.com> --- client/connection.go | 37 +++++++++++++++++++++++++++++++++++++ client/util.go | 16 +++++++++++++++- 2 files changed, 52 insertions(+), 1 deletion(-) diff --git a/client/connection.go b/client/connection.go index 9be6e4c3a1..43a6355026 100644 --- a/client/connection.go +++ b/client/connection.go @@ -67,6 +67,43 @@ func ConnectLXD(url string, args *ConnectionArgs) (InstanceServer, error) { return httpsLXD(url, args) } +// ConnectVMAgent lets you connect to a VM agent over a VM socket. 
+func ConnectVMAgent(vsockID int, args *ConnectionArgs) (InstanceServer, error) { + logger.Debugf("Connecting to a VM agent over a VM socket") + + // Use empty args if not specified + if args == nil { + args = &ConnectionArgs{} + } + + // Initialize the client struct + server := ProtocolLXD{ + httpHost: "http://vm.socket", + httpProtocol: "vsock", + httpUserAgent: args.UserAgent, + } + + // Setup the HTTP client + httpClient, err := vsockHTTPClient(args.HTTPClient, vsockID) + if err != nil { + return nil, err + } + server.http = httpClient + + // Test the connection and seed the server information + if !args.SkipGetServer { + serverStatus, _, err := server.GetServer() + if err != nil { + return nil, err + } + + // Record the server certificate + server.httpCertificate = serverStatus.Environment.Certificate + } + + return &server, nil +} + // ConnectLXDUnix lets you connect to a remote LXD daemon over a local unix socket. // // If the path argument is empty, then $LXD_SOCKET will be used, if diff --git a/client/util.go b/client/util.go index 6e13123e82..60aaec59d2 100644 --- a/client/util.go +++ b/client/util.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/lxc/lxd/shared" + "github.com/mdlayher/vsock" ) func tlsHTTPClient(client *http.Client, tlsClientCert string, tlsClientKey string, tlsCA string, tlsServerCert string, insecureSkipVerify bool, proxy func(req *http.Request) (*url.URL, error)) (*http.Client, error) { @@ -103,6 +104,15 @@ func tlsHTTPClient(client *http.Client, tlsClientCert string, tlsClientKey strin return client, nil } +func vsockHTTPClient(client *http.Client, vsockID int) (*http.Client, error) { + // Setup a VM socket dialer + vsockDial := func(network, addr string) (net.Conn, error) { + return vsock.Dial(uint32(vsockID), 8443) + } + + return socketHTTPClient(vsockDial, client) +} + func unixHTTPClient(client *http.Client, path string) (*http.Client, error) { // Setup a Unix socket dialer unixDial := func(network, addr string) (net.Conn, 
error) { @@ -114,9 +124,13 @@ func unixHTTPClient(client *http.Client, path string) (*http.Client, error) { return net.DialUnix("unix", nil, raddr) } + return socketHTTPClient(unixDial, client) +} + +func socketHTTPClient(dial func(network, addr string) (net.Conn, error), client *http.Client) (*http.Client, error) { // Define the http transport transport := &http.Transport{ - Dial: unixDial, + Dial: dial, DisableKeepAlives: true, } From f14827696d41953a5734f7998b436290ff523a1a Mon Sep 17 00:00:00 2001 From: Thomas Hipp <thomas.h...@canonical.com> Date: Wed, 25 Sep 2019 20:44:32 +0200 Subject: [PATCH 3/3] lxd-agent: Add basic structure Signed-off-by: Thomas Hipp <thomas.h...@canonical.com> --- lxd-agent/api.go | 25 ++ lxd-agent/api_1.0.go | 57 +++++ lxd-agent/exec.go | 14 ++ lxd-agent/file.go | 16 ++ lxd-agent/main.go | 103 ++++++++ lxd-agent/operations.go | 539 ++++++++++++++++++++++++++++++++++++++++ lxd-agent/response.go | 478 +++++++++++++++++++++++++++++++++++ lxd-agent/state.go | 19 ++ 8 files changed, 1251 insertions(+) create mode 100644 lxd-agent/api.go create mode 100644 lxd-agent/api_1.0.go create mode 100644 lxd-agent/exec.go create mode 100644 lxd-agent/file.go create mode 100644 lxd-agent/main.go create mode 100644 lxd-agent/operations.go create mode 100644 lxd-agent/response.go create mode 100644 lxd-agent/state.go diff --git a/lxd-agent/api.go b/lxd-agent/api.go new file mode 100644 index 0000000000..166c8f0532 --- /dev/null +++ b/lxd-agent/api.go @@ -0,0 +1,25 @@ +package main + +import "net/http" + +// APIEndpoint represents a URL in our API. +type APIEndpoint struct { + Name string // Name for this endpoint. + Path string // Path pattern for this endpoint. + Get APIEndpointAction + Put APIEndpointAction + Post APIEndpointAction + Delete APIEndpointAction + Patch APIEndpointAction +} + +// APIEndpointAlias represents an alias URL of and APIEndpoint in our API. +type APIEndpointAlias struct { + Name string // Name for this alias. 
+ Path string // Path pattern for this alias. +} + +// APIEndpointAction represents an action on an API endpoint. +type APIEndpointAction struct { + Handler func(r *http.Request) Response +} diff --git a/lxd-agent/api_1.0.go b/lxd-agent/api_1.0.go new file mode 100644 index 0000000000..633bc5060f --- /dev/null +++ b/lxd-agent/api_1.0.go @@ -0,0 +1,57 @@ +package main + +import ( + "net/http" + "os" + + "github.com/lxc/lxd/shared" + "github.com/lxc/lxd/shared/api" + "github.com/lxc/lxd/shared/version" +) + +var api10Cmd = APIEndpoint{ + Get: APIEndpointAction{Handler: api10Get}, +} + +var api10 = []APIEndpoint{ + execCmd, + fileCmd, + operationsCmd, + stateCmd, +} + +func api10Get(r *http.Request) Response { + srv := api.ServerUntrusted{ + APIExtensions: version.APIExtensions, // FIXME: use own API extensions + APIStatus: "stable", + APIVersion: version.APIVersion, // FIXME: use own API version + Public: false, + Auth: "trusted", + AuthMethods: []string{"tls"}, + } + + uname, err := shared.Uname() + if err != nil { + return InternalError(err) + } + + serverName, err := os.Hostname() + if err != nil { + return SmartError(err) + } + + env := api.ServerEnvironment{ + Kernel: uname.Sysname, + KernelArchitecture: uname.Machine, + KernelVersion: uname.Release, + Server: "lxd-agent", + ServerPid: os.Getpid(), + ServerVersion: version.Version, + ServerName: serverName, + } + + fullSrv := api.Server{ServerUntrusted: srv} + fullSrv.Environment = env + + return SyncResponseETag(true, fullSrv, fullSrv) +} diff --git a/lxd-agent/exec.go b/lxd-agent/exec.go new file mode 100644 index 0000000000..e3968dce0e --- /dev/null +++ b/lxd-agent/exec.go @@ -0,0 +1,14 @@ +package main + +import "net/http" + +var execCmd = APIEndpoint{ + Name: "exec", + Path: "exec", + + Post: APIEndpointAction{Handler: execPost}, +} + +func execPost(r *http.Request) Response { + return NotImplemented(nil) +} diff --git a/lxd-agent/file.go b/lxd-agent/file.go new file mode 100644 index 0000000000..bb819d00ff 
--- /dev/null +++ b/lxd-agent/file.go @@ -0,0 +1,16 @@ +package main + +import "net/http" + +var fileCmd = APIEndpoint{ + Name: "file", + Path: "files", + + Get: APIEndpointAction{Handler: fileHandler}, + Post: APIEndpointAction{Handler: fileHandler}, + Delete: APIEndpointAction{Handler: fileHandler}, +} + +func fileHandler(r *http.Request) Response { + return NotImplemented(nil) +} diff --git a/lxd-agent/main.go b/lxd-agent/main.go new file mode 100644 index 0000000000..2d99decd1c --- /dev/null +++ b/lxd-agent/main.go @@ -0,0 +1,103 @@ +package main + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + + "github.com/gorilla/mux" + "github.com/lxc/lxd/lxd/util" + "github.com/lxc/lxd/shared" + "github.com/lxc/lxd/shared/logger" +) + +// FIXME: Make this settable +var debug bool + +func main() { + mux := mux.NewRouter() + mux.StrictSlash(false) + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + SyncResponse(true, []string{"/1.0"}).Render(w) + }) + + for _, c := range api10 { + createCmd(mux, "1.0", c) + } + + // FIXME: Use ListenAndServeTLS once we know the location of the cert and keyfile + log.Println(http.ListenAndServe(":8443", mux)) +} + +func createCmd(restAPI *mux.Router, version string, c APIEndpoint) { + var uri string + if c.Path == "" { + uri = fmt.Sprintf("/%s", version) + } else { + uri = fmt.Sprintf("/%s/%s", version, c.Path) + } + + route := restAPI.HandleFunc(uri, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + // Dump full request JSON when in debug mode + if debug && r.Method != "GET" && util.IsJSONRequest(r) { + newBody := &bytes.Buffer{} + captured := &bytes.Buffer{} + multiW := io.MultiWriter(newBody, captured) + if _, err := io.Copy(multiW, r.Body); err != nil { + InternalError(err).Render(w) + return + } + + r.Body = shared.BytesReadCloser{Buf: newBody} + shared.DebugJson(captured) + } + + // Actually process 
the request + var resp Response + resp = NotImplemented(nil) + + handleRequest := func(action APIEndpointAction) Response { + if action.Handler == nil { + return NotImplemented(nil) + } + + return action.Handler(r) + } + + switch r.Method { + case "GET": + resp = handleRequest(c.Get) + case "PUT": + resp = handleRequest(c.Put) + case "POST": + resp = handleRequest(c.Post) + case "DELETE": + resp = handleRequest(c.Delete) + case "PATCH": + resp = handleRequest(c.Patch) + default: + resp = NotFound(fmt.Errorf("Method '%s' not found", r.Method)) + } + + // Handle errors + err := resp.Render(w) + if err != nil { + err := InternalError(err).Render(w) + if err != nil { + logger.Errorf("Failed writing error for error, giving up") + } + } + }) + + // If the endpoint has a canonical name then record it so it can be used to build URLS + // and accessed in the context of the request by the handler function. + if c.Name != "" { + route.Name(c.Name) + } +} diff --git a/lxd-agent/operations.go b/lxd-agent/operations.go new file mode 100644 index 0000000000..0723067bac --- /dev/null +++ b/lxd-agent/operations.go @@ -0,0 +1,539 @@ +package main + +import ( + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/pborman/uuid" + "github.com/pkg/errors" + + "github.com/lxc/lxd/lxd/db" + "github.com/lxc/lxd/lxd/events" + "github.com/lxc/lxd/shared" + "github.com/lxc/lxd/shared/api" + "github.com/lxc/lxd/shared/cancel" + "github.com/lxc/lxd/shared/logger" + "github.com/lxc/lxd/shared/version" +) + +var operationCmd = APIEndpoint{ + Path: "operations/{id}", + + Delete: APIEndpointAction{Handler: operationDelete}, + Get: APIEndpointAction{Handler: operationGet}, +} + +var operationsCmd = APIEndpoint{ + Path: "operations", + + Get: APIEndpointAction{Handler: operationsGet}, +} + +var operationWebsocket = APIEndpoint{ + Path: "operations/{id}/websocket", + + Get: APIEndpointAction{Handler: operationWebsocketGet}, +} + +func operationDelete(r *http.Request) 
Response { + return NotImplemented(nil) +} + +func operationGet(r *http.Request) Response { + return NotImplemented(nil) +} + +func operationsGet(r *http.Request) Response { + return NotImplemented(nil) +} + +func operationWebsocketGet(r *http.Request) Response { + return NotImplemented(nil) +} + +var operationsLock sync.Mutex +var operations map[string]*operation = make(map[string]*operation) + +type operationClass int + +const ( + operationClassTask operationClass = 1 + operationClassWebsocket operationClass = 2 + operationClassToken operationClass = 3 +) + +func (t operationClass) String() string { + return map[operationClass]string{ + operationClassTask: "task", + operationClassWebsocket: "websocket", + operationClassToken: "token", + }[t] +} + +type operation struct { + project string + id string + class operationClass + createdAt time.Time + updatedAt time.Time + status api.StatusCode + url string + resources map[string][]string + metadata map[string]interface{} + err string + readonly bool + canceler *cancel.Canceler + description string + permission string + + // Those functions are called at various points in the operation lifecycle + onRun func(*operation) error + onCancel func(*operation) error + onConnect func(*operation, *http.Request, http.ResponseWriter) error + + // Channels used for error reporting and state tracking of background actions + chanDone chan error + + // Locking for concurent access to the operation + lock sync.Mutex + + cluster *db.Cluster +} + +func (op *operation) done() { + if op.readonly { + return + } + + op.lock.Lock() + op.readonly = true + op.onRun = nil + op.onCancel = nil + op.onConnect = nil + close(op.chanDone) + op.lock.Unlock() + + time.AfterFunc(time.Second*5, func() { + operationsLock.Lock() + _, ok := operations[op.id] + if !ok { + operationsLock.Unlock() + return + } + + delete(operations, op.id) + operationsLock.Unlock() + + err := op.cluster.Transaction(func(tx *db.ClusterTx) error { + return 
tx.OperationRemove(op.id) + }) + if err != nil { + logger.Warnf("Failed to delete operation %s: %s", op.id, err) + } + }) +} + +func (op *operation) Run() (chan error, error) { + if op.status != api.Pending { + return nil, fmt.Errorf("Only pending operations can be started") + } + + chanRun := make(chan error, 1) + + op.lock.Lock() + op.status = api.Running + + if op.onRun != nil { + go func(op *operation, chanRun chan error) { + err := op.onRun(op) + if err != nil { + op.lock.Lock() + op.status = api.Failure + op.err = SmartError(err).String() + op.lock.Unlock() + op.done() + chanRun <- err + + logger.Debugf("Failure for %s operation: %s: %s", op.class.String(), op.id, err) + + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + return + } + + op.lock.Lock() + op.status = api.Success + op.lock.Unlock() + op.done() + chanRun <- nil + + op.lock.Lock() + logger.Debugf("Success for %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + op.lock.Unlock() + }(op, chanRun) + } + op.lock.Unlock() + + logger.Debugf("Started %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + + return chanRun, nil +} + +func (op *operation) Cancel() (chan error, error) { + if op.status != api.Running { + return nil, fmt.Errorf("Only running operations can be cancelled") + } + + if !op.mayCancel() { + return nil, fmt.Errorf("This operation can't be cancelled") + } + + chanCancel := make(chan error, 1) + + op.lock.Lock() + oldStatus := op.status + op.status = api.Cancelling + op.lock.Unlock() + + if op.onCancel != nil { + go func(op *operation, oldStatus api.StatusCode, chanCancel chan error) { + err := op.onCancel(op) + if err != nil { + op.lock.Lock() + op.status = oldStatus + op.lock.Unlock() + chanCancel <- err + + logger.Debugf("Failed to cancel %s operation: %s: %s", op.class.String(), op.id, err) + _, md, _ := op.Render() + events.Send(op.project, 
"operation", md) + return + } + + op.lock.Lock() + op.status = api.Cancelled + op.lock.Unlock() + op.done() + chanCancel <- nil + + logger.Debugf("Cancelled %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + }(op, oldStatus, chanCancel) + } + + logger.Debugf("Cancelling %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + + if op.canceler != nil { + err := op.canceler.Cancel() + if err != nil { + return nil, err + } + } + + if op.onCancel == nil { + op.lock.Lock() + op.status = api.Cancelled + op.lock.Unlock() + op.done() + chanCancel <- nil + } + + logger.Debugf("Cancelled %s operation: %s", op.class.String(), op.id) + _, md, _ = op.Render() + events.Send(op.project, "operation", md) + + return chanCancel, nil +} + +func (op *operation) Connect(r *http.Request, w http.ResponseWriter) (chan error, error) { + if op.class != operationClassWebsocket { + return nil, fmt.Errorf("Only websocket operations can be connected") + } + + if op.status != api.Running { + return nil, fmt.Errorf("Only running operations can be connected") + } + + chanConnect := make(chan error, 1) + + op.lock.Lock() + + go func(op *operation, chanConnect chan error) { + err := op.onConnect(op, r, w) + if err != nil { + chanConnect <- err + + logger.Debugf("Failed to handle %s operation: %s: %s", op.class.String(), op.id, err) + return + } + + chanConnect <- nil + + logger.Debugf("Handled %s operation: %s", op.class.String(), op.id) + }(op, chanConnect) + op.lock.Unlock() + + logger.Debugf("Connected %s operation: %s", op.class.String(), op.id) + + return chanConnect, nil +} + +func (op *operation) mayCancel() bool { + if op.class == operationClassToken { + return true + } + + if op.onCancel != nil { + return true + } + + if op.canceler != nil && op.canceler.Cancelable() { + return true + } + + return false +} + +func (op *operation) Render() (string, *api.Operation, 
error) { + // Setup the resource URLs + resources := op.resources + if resources != nil { + tmpResources := make(map[string][]string) + for key, value := range resources { + var values []string + for _, c := range value { + values = append(values, fmt.Sprintf("/%s/%s/%s", version.APIVersion, key, c)) + } + tmpResources[key] = values + } + resources = tmpResources + } + + // Local server name + var err error + var serverName string + err = op.cluster.Transaction(func(tx *db.ClusterTx) error { + serverName, err = tx.NodeName() + return err + }) + if err != nil { + return "", nil, err + } + + return op.url, &api.Operation{ + ID: op.id, + Class: op.class.String(), + Description: op.description, + CreatedAt: op.createdAt, + UpdatedAt: op.updatedAt, + Status: op.status.String(), + StatusCode: op.status, + Resources: resources, + Metadata: op.metadata, + MayCancel: op.mayCancel(), + Err: op.err, + Location: serverName, + }, nil +} + +func (op *operation) WaitFinal(timeout int) (bool, error) { + // Check current state + if op.status.IsFinal() { + return true, nil + } + + // Wait indefinitely + if timeout == -1 { + <-op.chanDone + return true, nil + } + + // Wait until timeout + if timeout > 0 { + timer := time.NewTimer(time.Duration(timeout) * time.Second) + select { + case <-op.chanDone: + return true, nil + + case <-timer.C: + return false, nil + } + } + + return false, nil +} + +func (op *operation) UpdateResources(opResources map[string][]string) error { + if op.status != api.Pending && op.status != api.Running { + return fmt.Errorf("Only pending or running operations can be updated") + } + + if op.readonly { + return fmt.Errorf("Read-only operations can't be updated") + } + + op.lock.Lock() + op.updatedAt = time.Now() + op.resources = opResources + op.lock.Unlock() + + logger.Debugf("Updated resources for %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + + return nil +} + +func (op *operation) 
UpdateMetadata(opMetadata interface{}) error { + if op.status != api.Pending && op.status != api.Running { + return fmt.Errorf("Only pending or running operations can be updated") + } + + if op.readonly { + return fmt.Errorf("Read-only operations can't be updated") + } + + newMetadata, err := shared.ParseMetadata(opMetadata) + if err != nil { + return err + } + + op.lock.Lock() + op.updatedAt = time.Now() + op.metadata = newMetadata + op.lock.Unlock() + + logger.Debugf("Updated metadata for %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + + return nil +} + +func operationCreate(cluster *db.Cluster, project string, opClass operationClass, opType db.OperationType, opResources map[string][]string, opMetadata interface{}, onRun func(*operation) error, onCancel func(*operation) error, onConnect func(*operation, *http.Request, http.ResponseWriter) error) (*operation, error) { + // Main attributes + op := operation{} + op.project = project + op.id = uuid.NewRandom().String() + op.description = opType.Description() + op.permission = opType.Permission() + op.class = opClass + op.createdAt = time.Now() + op.updatedAt = op.createdAt + op.status = api.Pending + op.url = fmt.Sprintf("/%s/operations/%s", version.APIVersion, op.id) + op.resources = opResources + op.chanDone = make(chan error) + op.cluster = cluster + + newMetadata, err := shared.ParseMetadata(opMetadata) + if err != nil { + return nil, err + } + op.metadata = newMetadata + + // Callback functions + op.onRun = onRun + op.onCancel = onCancel + op.onConnect = onConnect + + // Sanity check + if op.class != operationClassWebsocket && op.onConnect != nil { + return nil, fmt.Errorf("Only websocket operations can have a Connect hook") + } + + if op.class == operationClassWebsocket && op.onConnect == nil { + return nil, fmt.Errorf("Websocket operations must have a Connect hook") + } + + if op.class == operationClassToken && op.onRun != nil { + return nil, 
fmt.Errorf("Token operations can't have a Run hook") + } + + if op.class == operationClassToken && op.onCancel != nil { + return nil, fmt.Errorf("Token operations can't have a Cancel hook") + } + + operationsLock.Lock() + operations[op.id] = &op + operationsLock.Unlock() + + err = op.cluster.Transaction(func(tx *db.ClusterTx) error { + _, err := tx.OperationAdd(project, op.id, opType) + return err + }) + if err != nil { + return nil, errors.Wrapf(err, "failed to add operation %s to database", op.id) + } + + logger.Debugf("New %s operation: %s", op.class.String(), op.id) + _, md, _ := op.Render() + events.Send(op.project, "operation", md) + + return &op, nil +} + +func operationGetInternal(id string) (*operation, error) { + operationsLock.Lock() + op, ok := operations[id] + operationsLock.Unlock() + + if !ok { + return nil, fmt.Errorf("Operation '%s' doesn't exist", id) + } + + return op, nil +} + +type operationWebSocket struct { + req *http.Request + op *operation +} + +func (r *operationWebSocket) Render(w http.ResponseWriter) error { + chanErr, err := r.op.Connect(r.req, w) + if err != nil { + return err + } + + err = <-chanErr + return err +} + +func (r *operationWebSocket) String() string { + _, md, err := r.op.Render() + if err != nil { + return fmt.Sprintf("error: %s", err) + } + + return md.ID +} + +type forwardedOperationWebSocket struct { + req *http.Request + id string + source *websocket.Conn // Connection to the node were the operation is running +} + +func (r *forwardedOperationWebSocket) Render(w http.ResponseWriter) error { + target, err := shared.WebsocketUpgrader.Upgrade(w, r.req, nil) + if err != nil { + return err + } + <-shared.WebsocketProxy(r.source, target) + return nil +} + +func (r *forwardedOperationWebSocket) String() string { + return r.id +} diff --git a/lxd-agent/response.go b/lxd-agent/response.go new file mode 100644 index 0000000000..ebd964b040 --- /dev/null +++ b/lxd-agent/response.go @@ -0,0 +1,478 @@ +package main + +import ( + 
"bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "time" + + "github.com/pkg/errors" + + lxd "github.com/lxc/lxd/client" + "github.com/lxc/lxd/lxd/db" + "github.com/lxc/lxd/lxd/util" + "github.com/lxc/lxd/shared" + "github.com/lxc/lxd/shared/api" + "github.com/lxc/lxd/shared/version" +) + +type Response interface { + Render(w http.ResponseWriter) error + String() string +} + +// Sync response +type syncResponse struct { + success bool + etag interface{} + metadata interface{} + location string + code int + headers map[string]string +} + +func (r *syncResponse) Render(w http.ResponseWriter) error { + // Set an appropriate ETag header + if r.etag != nil { + etag, err := util.EtagHash(r.etag) + if err == nil { + w.Header().Set("ETag", etag) + } + } + + // Prepare the JSON response + status := api.Success + if !r.success { + status = api.Failure + } + + if r.headers != nil { + for h, v := range r.headers { + w.Header().Set(h, v) + } + } + + if r.location != "" { + w.Header().Set("Location", r.location) + code := r.code + if code == 0 { + code = 201 + } + w.WriteHeader(code) + } + + resp := api.ResponseRaw{ + Type: api.SyncResponse, + Status: status.String(), + StatusCode: int(status), + Metadata: r.metadata, + } + + return util.WriteJSON(w, resp, debug) +} + +func (r *syncResponse) String() string { + if r.success { + return "success" + } + + return "failure" +} + +func SyncResponse(success bool, metadata interface{}) Response { + return &syncResponse{success: success, metadata: metadata} +} + +func SyncResponseETag(success bool, metadata interface{}, etag interface{}) Response { + return &syncResponse{success: success, metadata: metadata, etag: etag} +} + +func SyncResponseLocation(success bool, metadata interface{}, location string) Response { + return &syncResponse{success: success, metadata: metadata, location: location} +} + +func SyncResponseRedirect(address string) Response { + return &syncResponse{success: true, location: 
address, code: http.StatusPermanentRedirect} +} + +func SyncResponseHeaders(success bool, metadata interface{}, headers map[string]string) Response { + return &syncResponse{success: success, metadata: metadata, headers: headers} +} + +var EmptySyncResponse = &syncResponse{success: true, metadata: make(map[string]interface{})} + +type forwardedResponse struct { + client lxd.InstanceServer + request *http.Request +} + +func (r *forwardedResponse) Render(w http.ResponseWriter) error { + info, err := r.client.GetConnectionInfo() + if err != nil { + return err + } + + url := fmt.Sprintf("%s%s", info.Addresses[0], r.request.URL.RequestURI()) + forwarded, err := http.NewRequest(r.request.Method, url, r.request.Body) + if err != nil { + return err + } + for key := range r.request.Header { + forwarded.Header.Set(key, r.request.Header.Get(key)) + } + + httpClient, err := r.client.GetHTTPClient() + if err != nil { + return err + } + response, err := httpClient.Do(forwarded) + if err != nil { + return err + } + + for key := range response.Header { + w.Header().Set(key, response.Header.Get(key)) + } + + w.WriteHeader(response.StatusCode) + _, err = io.Copy(w, response.Body) + return err +} + +func (r *forwardedResponse) String() string { + return fmt.Sprintf("request to %s", r.request.URL) +} + +// ForwardedResponse takes a request directed to a node and forwards it to +// another node, writing back the response it gegs. 
+func ForwardedResponse(client lxd.InstanceServer, request *http.Request) Response { + return &forwardedResponse{ + client: client, + request: request, + } +} + +// File transfer response +type fileResponseEntry struct { + identifier string + path string + filename string + buffer []byte /* either a path or a buffer must be provided */ +} + +type fileResponse struct { + req *http.Request + files []fileResponseEntry + headers map[string]string + removeAfterServe bool +} + +func (r *fileResponse) Render(w http.ResponseWriter) error { + if r.headers != nil { + for k, v := range r.headers { + w.Header().Set(k, v) + } + } + + // No file, well, it's easy then + if len(r.files) == 0 { + return nil + } + + // For a single file, return it inline + if len(r.files) == 1 { + var rs io.ReadSeeker + var mt time.Time + var sz int64 + + if r.files[0].path == "" { + rs = bytes.NewReader(r.files[0].buffer) + mt = time.Now() + sz = int64(len(r.files[0].buffer)) + } else { + f, err := os.Open(r.files[0].path) + if err != nil { + return err + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return err + } + + mt = fi.ModTime() + sz = fi.Size() + rs = f + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", fmt.Sprintf("%d", sz)) + w.Header().Set("Content-Disposition", fmt.Sprintf("inline;filename=%s", r.files[0].filename)) + + http.ServeContent(w, r.req, r.files[0].filename, mt, rs) + if r.files[0].path != "" && r.removeAfterServe { + err := os.Remove(r.files[0].path) + if err != nil { + return err + } + } + + return nil + } + + // Now the complex multipart answer + body := &bytes.Buffer{} + mw := multipart.NewWriter(body) + + for _, entry := range r.files { + var rd io.Reader + if entry.path != "" { + fd, err := os.Open(entry.path) + if err != nil { + return err + } + defer fd.Close() + rd = fd + } else { + rd = bytes.NewReader(entry.buffer) + } + + fw, err := mw.CreateFormFile(entry.identifier, entry.filename) + if err != 
nil { + return err + } + + _, err = io.Copy(fw, rd) + if err != nil { + return err + } + } + mw.Close() + + w.Header().Set("Content-Type", mw.FormDataContentType()) + w.Header().Set("Content-Length", fmt.Sprintf("%d", body.Len())) + + _, err := io.Copy(w, body) + return err +} + +func (r *fileResponse) String() string { + return fmt.Sprintf("%d files", len(r.files)) +} + +func FileResponse(r *http.Request, files []fileResponseEntry, headers map[string]string, removeAfterServe bool) Response { + return &fileResponse{r, files, headers, removeAfterServe} +} + +// Operation response +type operationResponse struct { + op *operation +} + +func (r *operationResponse) Render(w http.ResponseWriter) error { + _, err := r.op.Run() + if err != nil { + return err + } + + url, md, err := r.op.Render() + if err != nil { + return err + } + + body := api.ResponseRaw{ + Type: api.AsyncResponse, + Status: api.OperationCreated.String(), + StatusCode: int(api.OperationCreated), + Operation: url, + Metadata: md, + } + + w.Header().Set("Location", url) + w.WriteHeader(202) + + return util.WriteJSON(w, body, debug) +} + +func (r *operationResponse) String() string { + _, md, err := r.op.Render() + if err != nil { + return fmt.Sprintf("error: %s", err) + } + + return md.ID +} + +func OperationResponse(op *operation) Response { + return &operationResponse{op} +} + +// Forwarded operation response. 
+//
+// Returned when the operation has been created on another node
+type forwardedOperationResponse struct {
+	op      *api.Operation
+	project string
+}
+
+// Render replies with 202 Accepted and the remote operation as metadata,
+// pointing the client at the operation URL through the Location header.
+func (r *forwardedOperationResponse) Render(w http.ResponseWriter) error {
+	// NOTE(review): r.project is not URL-escaped; this is only safe if
+	// project names are restricted to URL-safe characters - confirm.
+	opURL := fmt.Sprintf("/%s/operations/%s", version.APIVersion, r.op.ID)
+	if r.project != "" {
+		opURL += fmt.Sprintf("?project=%s", r.project)
+	}
+
+	body := api.ResponseRaw{
+		Type:       api.AsyncResponse,
+		Status:     api.OperationCreated.String(),
+		StatusCode: int(api.OperationCreated),
+		Operation:  opURL,
+		Metadata:   r.op,
+	}
+
+	// All headers must be set before WriteHeader; later changes are ignored.
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Location", opURL)
+	w.WriteHeader(http.StatusAccepted)
+
+	return util.WriteJSON(w, body, debug)
+}
+
+func (r *forwardedOperationResponse) String() string {
+	return r.op.ID
+}
+
+// ForwardedOperationResponse creates a response that forwards the metadata of
+// an operation created on another node.
+func ForwardedOperationResponse(project string, op *api.Operation) Response {
+	return &forwardedOperationResponse{
+		op:      op,
+		project: project,
+	}
+}
+
+// Error response
+type errorResponse struct {
+	code int
+	msg  string
+}
+
+func (r *errorResponse) String() string {
+	return r.msg
+}
+
+// Render writes a JSON object with "type", "error" and "error_code" keys,
+// using r.code as the HTTP status. When debug is set, a copy of the JSON is
+// also passed to shared.DebugJson.
+func (r *errorResponse) Render(w http.ResponseWriter) error {
+	var output io.Writer
+
+	buf := &bytes.Buffer{}
+	output = buf
+	var captured *bytes.Buffer
+	if debug {
+		captured = &bytes.Buffer{}
+		output = io.MultiWriter(buf, captured)
+	}
+
+	err := json.NewEncoder(output).Encode(shared.Jmap{"type": api.ErrorResponse, "error": r.msg, "error_code": r.code})
+	if err != nil {
+		return err
+	}
+
+	if debug {
+		shared.DebugJson(captured)
+	}
+
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("X-Content-Type-Options", "nosniff")
+	w.WriteHeader(r.code)
+	fmt.Fprintln(w, buf.String())
+
+	return nil
+}
+
+// newErrorResponse builds an errorResponse with the given HTTP status code.
+// The message is err's text, or defaultMsg when err is nil, so every
+// constructor below is safe to call with a nil error.
+func newErrorResponse(code int, err error, defaultMsg string) Response {
+	message := defaultMsg
+	if err != nil {
+		message = err.Error()
+	}
+
+	return &errorResponse{code, message}
+}
+
+// NotImplemented returns a 501 error response ("not implemented" if err is nil).
+func NotImplemented(err error) Response {
+	return newErrorResponse(http.StatusNotImplemented, err, "not implemented")
+}
+
+// NotFound returns a 404 error response ("not found" if err is nil).
+func NotFound(err error) Response {
+	return newErrorResponse(http.StatusNotFound, err, "not found")
+}
+
+// Forbidden returns a 403 error response ("not authorized" if err is nil).
+func Forbidden(err error) Response {
+	return newErrorResponse(http.StatusForbidden, err, "not authorized")
+}
+
+// Conflict returns a 409 error response ("already exists" if err is nil).
+func Conflict(err error) Response {
+	return newErrorResponse(http.StatusConflict, err, "already exists")
+}
+
+// Unavailable returns a 503 error response ("unavailable" if err is nil).
+func Unavailable(err error) Response {
+	return newErrorResponse(http.StatusServiceUnavailable, err, "unavailable")
+}
+
+// BadRequest returns a 400 error response ("bad request" if err is nil).
+func BadRequest(err error) Response {
+	return newErrorResponse(http.StatusBadRequest, err, "bad request")
+}
+
+// InternalError returns a 500 error response ("internal error" if err is nil).
+func InternalError(err error) Response {
+	return newErrorResponse(http.StatusInternalServerError, err, "internal error")
+}
+
+// PreconditionFailed returns a 412 error response ("precondition failed" if
+// err is nil).
+func PreconditionFailed(err error) Response {
+	return newErrorResponse(http.StatusPreconditionFailed, err, "precondition failed")
+}
+
+/*
+ * SmartError returns the right error message based on err.
+ *
+ * A nil err yields EmptySyncResponse. When the cause (errors.Cause) is
+ * os.ErrNotExist or db.ErrNoSuchObject the result is a 404, os.ErrPermission
+ * gives a 403 and db.ErrAlreadyDefined a 409; anything else is a 500 carrying
+ * err's message. For the mapped causes, err's own message is reported only
+ * when err wraps the cause; a bare sentinel gets the generic default text.
+ */
+func SmartError(err error) Response {
+	if err == nil {
+		return EmptySyncResponse
+	}
+
+	cause := errors.Cause(err)
+
+	// A nil interface here makes the constructors fall back to their
+	// default message.
+	var detail error
+	if cause != err {
+		detail = err
+	}
+
+	switch cause {
+	case os.ErrNotExist, db.ErrNoSuchObject:
+		return NotFound(detail)
+	case os.ErrPermission:
+		return Forbidden(detail)
+	case db.ErrAlreadyDefined:
+		return Conflict(detail)
+	default:
+		return InternalError(err)
+	}
+}
diff --git a/lxd-agent/state.go b/lxd-agent/state.go
new file mode 100644
index 0000000000..e0d1806b08
--- /dev/null
+++ b/lxd-agent/state.go
@@ -0,0 +1,19 @@
+package main
+
+import "net/http"
+
+var stateCmd = APIEndpoint{
+	Name: "state",
+	Path: "state",
+
+	Get: APIEndpointAction{Handler: stateGet},
+	Put: APIEndpointAction{Handler: statePut},
+}
+
+func stateGet(r *http.Request) Response {
+	return NotImplemented(nil)
+}
+
+func statePut(r *http.Request) Response {
+	return NotImplemented(nil)
+}
_______________________________________________ lxc-devel mailing list lxc-devel@lists.linuxcontainers.org http://lists.linuxcontainers.org/listinfo/lxc-devel