package modelgen

import (
	"encoding/json"
	"fmt"
	"testing"
	"text/template"

	"github.com/ovn-org/libovsdb/ovsdb"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNewTableTemplate renders the default table template for a small schema
// with atomic and enum columns, optionally customized per case through
// extend, and compares the formatted output with the expected Go source.
// Cases with formatErr set expect Generator.Format to fail instead.
func TestNewTableTemplate(t *testing.T) {
	rawSchema := []byte(`
	{
		"name": "AtomicDB",
		"version": "0.0.0",
		"tables": {
			"atomicTable": {
				"columns": {
					"str": {
						"type": "string"
					},
					"int": {
						"type": "integer"
					},
					"float": {
						"type": "real"
					},
					"protocol": {
						"type": {"key": {"type": "string",
								 "enum": ["set", ["tcp", "udp", "sctp"]]},
								 "min": 0, "max": 1}},
					"event_type": {"type": {"key": {"type": "string",
													"enum": ["set", ["empty_lb_backends"]]}}}
				}
			}
		}
	}`)

	test := []struct {
		name      string
		extend    func(tmpl *template.Template, data TableTemplateData)
		expected  string
		formatErr bool
	}{
		{
			name: "normal",
			expected: `// Code generated by "libovsdb.modelgen"
// DO NOT EDIT.

package test

type (
	AtomicTableEventType = string
	AtomicTableProtocol  = string
)

const (
	AtomicTableEventTypeEmptyLbBackends AtomicTableEventType = "empty_lb_backends"
	AtomicTableProtocolTCP              AtomicTableProtocol  = "tcp"
	AtomicTableProtocolUDP              AtomicTableProtocol  = "udp"
	AtomicTableProtocolSCTP             AtomicTableProtocol  = "sctp"
)

// AtomicTable defines an object in atomicTable table
type AtomicTable struct {
	UUID      string                ` + "`" + `ovsdb:"_uuid"` + "`" + `
	EventType AtomicTableEventType  ` + "`" + `ovsdb:"event_type"` + "`" + `
	Float     float64               ` + "`" + `ovsdb:"float"` + "`" + `
	Int       int                   ` + "`" + `ovsdb:"int"` + "`" + `
	Protocol  []AtomicTableProtocol ` + "`" + `ovsdb:"protocol"` + "`" + `
	Str       string                ` + "`" + `ovsdb:"str"` + "`" + `
}
`,
		},
		{
			name: "no enums",
			extend: func(tmpl *template.Template, data TableTemplateData) {
				data.WithEnumTypes(false)
			},
			expected: `// Code generated by "libovsdb.modelgen"
// DO NOT EDIT.

package test

// AtomicTable defines an object in atomicTable table
type AtomicTable struct {
	UUID      string   ` + "`" + `ovsdb:"_uuid"` + "`" + `
	EventType string   ` + "`" + `ovsdb:"event_type"` + "`" + `
	Float     float64  ` + "`" + `ovsdb:"float"` + "`" + `
	Int       int      ` + "`" + `ovsdb:"int"` + "`" + `
	Protocol  []string ` + "`" + `ovsdb:"protocol"` + "`" + `
	Str       string   ` + "`" + `ovsdb:"str"` + "`" + `
}
`,
		},
		{
			name: "add fields using same data",
			extend: func(tmpl *template.Template, data TableTemplateData) {
				extra := `{{ define "extraFields" }} {{- $tableName := index . "TableName" }} {{ range $field := index . "Fields"  }}	Other{{ FieldName $field.Column }}  {{ FieldType $tableName $field.Column $field.Schema }}
{{ end }}
{{- end }}`
				_, err := tmpl.Parse(extra)
				if err != nil {
					panic(err)
				}
			},
			expected: `// Code generated by "libovsdb.modelgen"
// DO NOT EDIT.

package test

type (
	AtomicTableEventType = string
	AtomicTableProtocol  = string
)

const (
	AtomicTableEventTypeEmptyLbBackends AtomicTableEventType = "empty_lb_backends"
	AtomicTableProtocolTCP              AtomicTableProtocol  = "tcp"
	AtomicTableProtocolUDP              AtomicTableProtocol  = "udp"
	AtomicTableProtocolSCTP             AtomicTableProtocol  = "sctp"
)

// AtomicTable defines an object in atomicTable table
type AtomicTable struct {
	UUID      string                ` + "`" + `ovsdb:"_uuid"` + "`" + `
	EventType AtomicTableEventType  ` + "`" + `ovsdb:"event_type"` + "`" + `
	Float     float64               ` + "`" + `ovsdb:"float"` + "`" + `
	Int       int                   ` + "`" + `ovsdb:"int"` + "`" + `
	Protocol  []AtomicTableProtocol ` + "`" + `ovsdb:"protocol"` + "`" + `
	Str       string                ` + "`" + `ovsdb:"str"` + "`" + `

	OtherUUID      string
	OtherEventType string
	OtherFloat     float64
	OtherInt       int
	OtherProtocol  []string
	OtherStr       string
}
`,
		},
		{
			name: "add extra functions using extra data",
			extend: func(tmpl *template.Template, data TableTemplateData) {
				extra := `{{ define "postStructDefinitions" }}
func {{ index . "TestName" }} () string {
    return "{{ index . "StructName" }}"
} {{ end }}
`
				_, err := tmpl.Parse(extra)
				if err != nil {
					panic(err)
				}
				data["TestName"] = "TestFunc"
			},
			expected: `// Code generated by "libovsdb.modelgen"
// DO NOT EDIT.

package test

type (
	AtomicTableEventType = string
	AtomicTableProtocol  = string
)

const (
	AtomicTableEventTypeEmptyLbBackends AtomicTableEventType = "empty_lb_backends"
	AtomicTableProtocolTCP              AtomicTableProtocol  = "tcp"
	AtomicTableProtocolUDP              AtomicTableProtocol  = "udp"
	AtomicTableProtocolSCTP             AtomicTableProtocol  = "sctp"
)

// AtomicTable defines an object in atomicTable table
type AtomicTable struct {
	UUID      string                ` + "`" + `ovsdb:"_uuid"` + "`" + `
	EventType AtomicTableEventType  ` + "`" + `ovsdb:"event_type"` + "`" + `
	Float     float64               ` + "`" + `ovsdb:"float"` + "`" + `
	Int       int                   ` + "`" + `ovsdb:"int"` + "`" + `
	Protocol  []AtomicTableProtocol ` + "`" + `ovsdb:"protocol"` + "`" + `
	Str       string                ` + "`" + `ovsdb:"str"` + "`" + `
}

func TestFunc() string {
	return "AtomicTable"
}
`,
		},
		{
			name:      "add bad code",
			formatErr: true,
			extend: func(tmpl *template.Template, data TableTemplateData) {
				extra := `{{ define "preStructDefinitions" }}
WRONG FORMAT
{{ end }}
`
				_, err := tmpl.Parse(extra)
				if err != nil {
					panic(err)
				}
			},
		},
	}

	var schema ovsdb.DatabaseSchema
	// Scope the unmarshal error to this check so later assertions cannot
	// accidentally inspect it instead of the error they mean to test.
	if err := json.Unmarshal(rawSchema, &schema); err != nil {
		t.Fatal(err)
	}

	for _, tt := range test {
		t.Run(fmt.Sprintf("Table Test: %s", tt.name), func(t *testing.T) {
			fakeTable := "atomicTable"
			// Build a fresh template and data map per case: extend may parse
			// extra definitions into tmpl and add keys to data.
			tmpl := NewTableTemplate()
			table := schema.Tables[fakeTable]
			data := GetTableTemplateData(
				"test",
				fakeTable,
				&table,
			)
			if tt.extend != nil {
				tt.extend(tmpl, data)
			}
			// Render several times, each with a new generator, so that output
			// varying between runs (presumably from map iteration order) is caught.
			for i := 0; i < 3; i++ {
				g, err := NewGenerator()
				require.NoError(t, err)
				b, err := g.Format(tmpl, data)
				if tt.formatErr {
					assert.Error(t, err)
					continue
				}
				require.NoError(t, err)
				assert.Equal(t, tt.expected, string(b))
			}
		})
	}
}

// TestFieldName checks that FieldName maps a column name to the expected
// Go struct field name (e.g. "foo" becomes "Foo").
func TestFieldName(t *testing.T) {
	type fieldNameCase struct {
		column string
		want   string
	}
	for _, c := range []fieldNameCase{
		{column: "foo", want: "Foo"},
	} {
		got := FieldName(c.column)
		if got != c.want {
			t.Fatalf("got %s, wanted %s", got, c.want)
		}
	}
}

// TestStructName checks that StructName turns the table name "Foo_Bar"
// into the Go type name "FooBar".
func TestStructName(t *testing.T) {
	const tableName, want = "Foo_Bar", "FooBar"
	got := StructName(tableName)
	if got != want {
		t.Fatalf("got %s, wanted FooBar", got)
	}
}

/*
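TODO: write a test for FieldType. The sketch below is incomplete (the test
table has no entries and tt is undefined) and does not compile. The template
in TestNewTableTemplate calls FieldType with three arguments (table name,
column name, column schema), so the single-argument call below is likely outdated.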
func TestFieldType(t *testing.T) {
	tests := []struct {
		name string
		in   *ovsdb.ColumnSchema
		out  string
	}
	if got := FieldType(tt.args.column); got != tt.want {
		t.Errorf("FieldType() = %v, want %v", got, tt.want)
	}
}
*/

// TestAtomicType checks the mapping from OVSDB atomic type names to Go type
// names, including the empty result for an unknown type.
func TestAtomicType(t *testing.T) {
	cases := []struct {
		name, ovsdbType, goType string
	}{
		{name: "IntegerToInt", ovsdbType: ovsdb.TypeInteger, goType: "int"},
		{name: "RealToFloat", ovsdbType: ovsdb.TypeReal, goType: "float64"},
		{name: "BooleanToBool", ovsdbType: ovsdb.TypeBoolean, goType: "bool"},
		{name: "StringToString", ovsdbType: ovsdb.TypeString, goType: "string"},
		{name: "UUIDToString", ovsdbType: ovsdb.TypeUUID, goType: "string"},
		{name: "Invalid", ovsdbType: "notAType", goType: ""},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got := AtomicType(c.ovsdbType)
			if got == c.goType {
				return
			}
			t.Errorf("got %s, wanted %s", got, c.goType)
		})
	}
}

// TestTag checks that Tag wraps a column name in an ovsdb struct tag.
func TestTag(t *testing.T) {
	want := `ovsdb:"Foo_Bar"`
	got := Tag("Foo_Bar")
	if got != want {
		t.Fatalf("got %s, wanted ovsdb:\"Foo_Bar\"", got)
	}
}

// TestFileName checks that FileName appends the ".go" extension.
func TestFileName(t *testing.T) {
	got := FileName("foo")
	if want := "foo.go"; got != want {
		t.Fatalf("got %s, wanted foo.go", got)
	}
}

// TestCamelCase checks that camelCase joins '_'- or '-'-separated words into
// a single capitalized identifier and upper-cases initialisms such as IP, IDs
// and DNS.
//
// Each case runs as its own subtest and reports with Errorf, so one mismatch
// does not hide the results of the remaining cases.
func TestCamelCase(t *testing.T) {
	cases := []struct {
		in       string
		expected string
	}{
		{"foo_bar_baz", "FooBarBaz"},
		{"foo-bar-baz", "FooBarBaz"},
		{"foos-bars-bazs", "FoosBarsBazs"},
		{"ip_port_mappings", "IPPortMappings"},
		{"external_ids", "ExternalIDs"},
		{"ip_prefix", "IPPrefix"},
		{"dns_records", "DNSRecords"},
		{"logical_ip", "LogicalIP"},
		{"ip", "IP"},
	}
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			if s := camelCase(tt.in); s != tt.expected {
				t.Errorf("camelCase(%q) = %q, wanted %q", tt.in, s, tt.expected)
			}
		})
	}
}

// ExampleNewTableTemplate shows how to extend the default table template with
// an extra definition and render it for one table of a schema.
//
// The example has no "// Output:" comment, so `go test` compiles it but does
// not execute it.
func ExampleNewTableTemplate() {
	schemaString := []byte(`
	{
		"name": "MyDB",
		"version": "0.0.0",
		"tables": {
			"table1": {
				"columns": {
					"string_column": {
						"type": "string"
					},
					"some_integer": {
						"type": "integer"
					}
				}
			}
		}
	}`)
	var schema ovsdb.DatabaseSchema
	if err := json.Unmarshal(schemaString, &schema); err != nil {
		panic(err)
	}

	base := NewTableTemplate()
	data := GetTableTemplateData("mypackage", "table1", schema.Table("table1"))

	// Add a function after the struct definition.
	// The template can access the default data values plus any extra key
	// that is added to data (here, "FuncName").
	if _, err := base.Parse(`{{define "postStructDefinitions"}}
func (t {{ index . "StructName" }}) {{ index . "FuncName"}}() string {
    return "bar"
}{{end}}`); err != nil {
		panic(err)
	}
	data["FuncName"] = "TestFunc"

	gen, err := NewGenerator(WithDryRun())
	if err != nil {
		panic(err)
	}
	if err := gen.Generate("generated.go", base, data); err != nil {
		panic(err)
	}
}
