package save

import (
	"crypto/rand"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync/atomic"
	"testing"

	"github.com/hashicorp/consul/agent"
	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/lib"
	"github.com/hashicorp/consul/sdk/testutil"
	"github.com/mitchellh/cli"
	"github.com/stretchr/testify/require"
)

// TestSnapshotSaveCommand_noTabs fails if the help text produced by the save
// command contains any tab character.
func TestSnapshotSaveCommand_noTabs(t *testing.T) {
	t.Parallel()

	helpText := New(cli.NewMockUi()).Help()
	if pos := strings.IndexRune(helpText, '\t'); pos >= 0 {
		t.Fatal("help has tabs")
	}
}

// TestSnapshotSaveCommand_Validation checks that invalid argument lists make
// the command exit non-zero and write the expected message to the UI's error
// writer. Each case runs as its own named subtest.
func TestSnapshotSaveCommand_Validation(t *testing.T) {
	t.Parallel()

	cases := map[string]struct {
		args   []string
		output string
	}{
		"no file": {
			[]string{},
			"Missing FILE argument",
		},
		"extra args": {
			[]string{"foo", "bar", "baz"},
			"Too many arguments",
		},
	}

	for name, tc := range cases {
		// t.Run is called without t.Parallel, so each subtest finishes before
		// the loop advances and capturing tc is safe.
		t.Run(name, func(t *testing.T) {
			// Every case builds its own MockUi, so output from one case
			// cannot leak into another; no buffer reset is needed.
			ui := cli.NewMockUi()
			c := New(ui)

			if code := c.Run(tc.args); code == 0 {
				t.Error("expected non-zero exit")
			}

			output := ui.ErrorWriter.String()
			if !strings.Contains(output, tc.output) {
				t.Errorf("expected %q to contain %q", output, tc.output)
			}
		})
	}
}

// TestSnapshotSaveCommand runs the save command against a live test agent,
// writing the snapshot to a temp file. It then restores that file through the
// same agent's snapshot API to confirm the output is usable.
func TestSnapshotSaveCommand(t *testing.T) {
	t.Parallel()
	a := agent.NewTestAgent(t, t.Name(), ``)
	defer a.Shutdown()
	client := a.Client()

	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)
	backupPath := filepath.Join(dir, "backup.tgz")

	ui := cli.NewMockUi()
	cmd := New(ui)
	if exitCode := cmd.Run([]string{"-http-addr=" + a.HTTPAddr(), backupPath}); exitCode != 0 {
		t.Fatalf("bad: %d. %#v", exitCode, ui.ErrorWriter.String())
	}

	// Round-trip: the file the command wrote must be accepted by Restore.
	backup, err := os.Open(backupPath)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	defer backup.Close()

	if err = client.Snapshot().Restore(nil, backup); err != nil {
		t.Fatalf("err: %v", err)
	}
}

// TestSnapshotSaveCommand_TruncatedStream points the save command at a fake
// /v1/snapshot endpoint that serves a real snapshot with bytes cut off the end.
// For each truncation size it checks three things:
//   - the command exits with code 1;
//   - it reports a verification error mentioning EOF;
//   - it leaves neither the target file nor its ".unverified" companion on disk.
func TestSnapshotSaveCommand_TruncatedStream(t *testing.T) {
	t.Parallel()
	a := agent.NewTestAgent(t, t.Name(), ``)
	defer a.Shutdown()
	client := a.Client()

	// Seed the agent's KV store with 64K of random data so the snapshot has
	// some real content to work with.
	{
		blob := make([]byte, 64*1024)
		_, err := rand.Read(blob)
		require.NoError(t, err)

		_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: blob}, nil)
		require.NoError(t, err)
	}

	// Take a real snapshot so the fake server can send back realistic data.
	var inputData []byte
	{
		rc, _, err := client.Snapshot().Save(nil)
		require.NoError(t, err)

		// Close right after draining the body. A defer here would not fire at
		// the end of this block; it would fire only when the whole test
		// returns, keeping the response open throughout.
		inputData, err = ioutil.ReadAll(rc)
		_ = rc.Close() // best-effort: the body has already been fully read
		require.NoError(t, err)
	}

	var fakeResult atomic.Value

	// Run a fake webserver that pretends to be the snapshot API and serves
	// whatever []byte is currently stored in fakeResult.
	fakeAddr, cleanup := lib.StartTestServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if req.URL.Path != "/v1/snapshot" {
			w.WriteHeader(http.StatusNotFound)
			return
		}
		if req.Method != http.MethodGet {
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		raw := fakeResult.Load()
		if raw == nil {
			w.WriteHeader(http.StatusNotFound)
			return
		}

		data := raw.([]byte)
		_, _ = w.Write(data)
	}))
	defer cleanup()

	dir := testutil.TempDir(t, "snapshot")
	defer os.RemoveAll(dir)

	for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
		// Subtests run sequentially (no t.Parallel). That makes it safe to
		// capture removeBytes and to reuse the same file path and fakeResult.
		t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
			// Fail clearly instead of panicking on a negative slice bound if
			// the snapshot is unexpectedly small.
			require.True(t, removeBytes < len(inputData),
				"snapshot too small (%d bytes) to truncate by %d", len(inputData), removeBytes)

			// Lop off part of the end.
			data := inputData[0 : len(inputData)-removeBytes]

			fakeResult.Store(data)

			ui := cli.NewMockUi()
			c := New(ui)

			file := filepath.Join(dir, "backup.tgz")
			args := []string{
				"-http-addr=" + fakeAddr, // point to the fake
				file,
			}

			code := c.Run(args)
			require.Equal(t, 1, code, "expected non-zero exit")

			output := ui.ErrorWriter.String()
			require.Contains(t, output, "Error verifying snapshot file")
			require.Contains(t, output, "EOF")

			// The final file should not have been created.
			_, err := os.Stat(file)
			require.Error(t, err, "file is not supposed to exist")
			require.True(t, os.IsNotExist(err), "file is not supposed to exist")

			// The unverified staging file should be gone as well.
			_, err = os.Stat(file + ".unverified")
			require.Error(t, err, "unverified file is not supposed to exist")
			require.True(t, os.IsNotExist(err), "unverified file is not supposed to exist")
		})
	}
}
