package db

import (
	"fmt"
	"io/ioutil"
	"os"

	"github.com/cheggaaa/pb"
	"github.com/jinzhu/gorm"
	"github.com/k0kubun/pp"
	c "github.com/kotakanbe/go-cve-dictionary/config"
	"github.com/kotakanbe/go-cve-dictionary/jvn"
	log "github.com/kotakanbe/go-cve-dictionary/log"
	"github.com/kotakanbe/go-cve-dictionary/models"
	"github.com/kotakanbe/go-cve-dictionary/nvd"
	// Required MySQL.  See http://jinzhu.me/gorm/database.html#connecting-to-a-database
	_ "github.com/jinzhu/gorm/dialects/mysql"
	_ "github.com/jinzhu/gorm/dialects/postgres"
	// Required SQLite3.
	_ "github.com/jinzhu/gorm/dialects/sqlite"
)

// Supported DB dialects. Each value is a dialect name understood by
// gorm.Open (the matching dialect packages are blank-imported above).
const (
	dialectMysql      = "mysql"    // MySQL
	dialectPostgreSQL = "postgres" // PostgreSQL
	dialectSqlite3    = "sqlite3"  // SQLite3
)

// RDBDriver is Driver for RDB: it stores CVE data in a relational database
// accessed through gorm.
type RDBDriver struct {
	name string   // DB type (gorm dialect name) passed to NewRDB
	conn *gorm.DB // connection assigned by OpenDB
}

// Name return db name, i.e. the DB type (gorm dialect name) the driver was
// created with.
func (r *RDBDriver) Name() string {
	return r.name
}

// NewRDB return RDB driver.
//
// It opens the database of type dbType at dbpath (see OpenDB) and then runs
// MigrateDB. The returned driver is non-nil even when err is non-nil, so
// callers must check err before using it.
func NewRDB(dbType, dbpath string, debugSQL bool) (driver *RDBDriver, err error) {
	driver = &RDBDriver{name: dbType}

	log.Debugf("Opening DB (%s).", driver.Name())
	err = driver.OpenDB(dbType, dbpath, debugSQL)
	if err != nil {
		return driver, err
	}

	log.Debugf("Migrating DB (%s).", driver.Name())
	err = driver.MigrateDB()
	return driver, err
}

// OpenDB opens Database.
//
// dbType is the gorm dialect name and dbPath the connection string handed to
// gorm.Open. debugSQL toggles gorm's SQL logging. For sqlite3 the journal
// mode is switched to WAL; that step is best-effort, so a failure is logged
// rather than returned.
func (r *RDBDriver) OpenDB(dbType, dbPath string, debugSQL bool) (err error) {
	r.conn, err = gorm.Open(dbType, dbPath)
	if err != nil {
		err = fmt.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %s", dbType, dbPath, err)
		return err
	}
	r.conn.LogMode(debugSQL)

	// Decide on the dialect actually opened, not on r.name, which may differ
	// (or be empty) when OpenDB is called on a driver not built by NewRDB.
	if dbType == dialectSqlite3 {
		if werr := r.conn.Exec("PRAGMA journal_mode=WAL;").Error; werr != nil {
			log.Errorf("Failed to enable WAL journal mode. dbpath: %s, err: %s", dbPath, werr)
		}
	}
	return nil
}

// MigrateDB migrates Database.
//
// It auto-migrates the CveDetail, Jvn, Nvd, Reference and Cpe tables and then
// creates the lookup indexes in a fixed order. The first failure aborts
// the migration and is returned.
func (r *RDBDriver) MigrateDB() error {
	migrateErr := r.conn.AutoMigrate(
		&models.CveDetail{},
		&models.Jvn{},
		&models.Nvd{},
		&models.Reference{},
		&models.Cpe{},
	).Error
	if migrateErr != nil {
		return fmt.Errorf("Failed to migrate. err: %s", migrateErr)
	}

	// Order matters only for which error is reported first; it mirrors the
	// historical creation order.
	indexes := []struct {
		model  interface{}
		name   string
		column string
	}{
		{&models.CveDetail{}, "idx_cve_detail_cveid", "cve_id"},
		{&models.Nvd{}, "idx_nvds_cve_detail_id", "cve_detail_id"},
		{&models.Jvn{}, "idx_jvns_cve_detail_id", "cve_detail_id"},
		{&models.Cpe{}, "idx_cpes_jvn_id", "jvn_id"},
		{&models.Reference{}, "idx_references_jvn_id", "jvn_id"},
		{&models.Cpe{}, "idx_cpes_nvd_id", "nvd_id"},
		{&models.Cpe{}, "idx_cpes_cpe_name", "cpe_name"},
		{&models.Reference{}, "idx_references_nvd_id", "nvd_id"},
	}
	for _, idx := range indexes {
		if idxErr := r.conn.Model(idx.model).AddIndex(idx.name, idx.column).Error; idxErr != nil {
			return fmt.Errorf("Failed to create index. err: %s", idxErr)
		}
	}
	return nil
}

// Get Select Cve information from DB.
//
// It loads the CveDetail identified by cveID together with its JVN and NVD
// records and their References. The result is never nil.
//
// When cveID is empty or unknown, an empty CveDetail is returned. In every
// case the Reference and Cpe slices of both Jvn and Nvd are non-nil, so JSON
// encoding yields [] rather than null. Cpes are not loaded (see TODO below),
// so they are always empty. Query errors are not reported; a failed lookup
// behaves like a missing record.
func (r *RDBDriver) Get(cveID string) *models.CveDetail {
	// gorm drops zero-valued fields from struct conditions, so querying with
	// an empty cveID would be unconditioned and return an arbitrary first row.
	found := models.CveDetail{}
	if cveID != "" {
		r.conn.Where(&models.CveDetail{CveID: cveID}).First(&found)
	}

	detail := models.CveDetail{}
	if found.ID != 0 {
		detail = found

		// JVN
		jvnRec := models.Jvn{}
		r.conn.Model(&detail).Related(&jvnRec, "Jvn")
		detail.Jvn = jvnRec
		if jvnRec.CveDetailID != 0 && jvnRec.ID != 0 {
			jvnRefs := []models.Reference{}
			r.conn.Model(&jvnRec).Related(&jvnRefs, "References")
			detail.Jvn.References = jvnRefs
		}

		// NVD
		nvdRec := models.Nvd{}
		r.conn.Model(&detail).Related(&nvdRec, "Nvd")
		detail.Nvd = nvdRec
		if nvdRec.CveDetailID != 0 && nvdRec.ID != 0 {
			nvdRefs := []models.Reference{}
			r.conn.Model(&nvdRec).Related(&nvdRefs, "References")
			detail.Nvd.References = nvdRefs
		}

		// TODO Cpes are not loaded because the JSON response would become
		// large. Uncomment if needed, e.g.:
		//  jvnCpes := []models.Cpe{}
		//  r.conn.Model(&jvnRec).Related(&jvnCpes, "Cpes")
		//  detail.Jvn.Cpes = jvnCpes
	}

	// Avoid null slices in JSON, including when the JVN or NVD part is missing.
	if detail.Jvn.References == nil {
		detail.Jvn.References = []models.Reference{}
	}
	if detail.Jvn.Cpes == nil {
		detail.Jvn.Cpes = []models.Cpe{}
	}
	if detail.Nvd.References == nil {
		detail.Nvd.References = []models.Reference{}
	}
	if detail.Nvd.Cpes == nil {
		detail.Nvd.Cpes = []models.Cpe{}
	}
	return &detail
}

// GetMulti Select Cves information from DB.
//
// It returns a map from each requested CVE ID to the result of Get for that
// ID. An unknown ID therefore maps to an empty, non-nil CveDetail. Duplicate
// IDs are looked up once. The map is non-nil even for empty input.
func (r *RDBDriver) GetMulti(cveIDs []string) (cveDetails map[string]*models.CveDetail) {
	cveDetails = make(map[string]*models.CveDetail, len(cveIDs))
	for _, cveID := range cveIDs {
		if _, done := cveDetails[cveID]; done {
			continue
		}
		cveDetails[cveID] = r.Get(cveID)
	}
	return cveDetails
}

// CloseDB close Database.
//
// A close failure is both logged and returned to the caller.
func (r *RDBDriver) CloseDB() (err error) {
	err = r.conn.Close()
	if err != nil {
		log.Errorf("Failed to close DB. Type: %s. err: %s", r.name, err)
	}
	return err
}

// GetByCpeName Select Cve information from DB.
//
// It finds every Cpe row whose CpeName equals cpeName, resolves each one
// (through its Jvn or Nvd) to the owning CveDetail, and returns Get's
// result for each distinct CVE, in the order first encountered.
//
// An empty cpeName returns nil. CPE rows that cannot be resolved to a
// CveDetail are skipped. Query errors are not reported.
func (r *RDBDriver) GetByCpeName(cpeName string) (details []*models.CveDetail) {
	// gorm drops zero-valued fields from struct conditions; an empty name
	// would otherwise match every CPE row in the table.
	if cpeName == "" {
		return nil
	}

	cpes := []models.Cpe{}
	r.conn.Where(&models.Cpe{CpeName: cpeName}).Find(&cpes)

	// A CVE may own matching CPEs in both JVN and NVD; report it once.
	seen := make(map[string]bool)
	for _, cpe := range cpes {
		var cveDetailID uint
		switch {
		case cpe.JvnID != 0:
			//TODO test check CPE name format of JVN table
			jvnRec := models.Jvn{}
			r.conn.Select("cve_detail_id").Where("ID = ?", cpe.JvnID).First(&jvnRec)
			cveDetailID = jvnRec.CveDetailID
		case cpe.NvdID != 0:
			nvdRec := models.Nvd{}
			r.conn.Select("cve_detail_id").Where("ID = ?", cpe.NvdID).First(&nvdRec)
			cveDetailID = nvdRec.CveDetailID
		}
		if cveDetailID == 0 {
			continue
		}

		cveDetail := models.CveDetail{}
		r.conn.Select("cve_id").Where("ID = ?", cveDetailID).First(&cveDetail)
		// An empty CveID means the lookup failed; Get("") must not be
		// reached, as it could resolve to an unrelated row.
		if cveDetail.CveID == "" || seen[cveDetail.CveID] {
			continue
		}
		seen[cveDetail.CveID] = true
		details = append(details, r.Get(cveDetail.CveID))
	}
	return details
}

// InsertJvn insert items fetched from JVN.
//
// The items are converted to CveDetail records and upserted by
// insertIntoJvn, whose error (if any) is returned unchanged.
func (r *RDBDriver) InsertJvn(items []jvn.Item) error {
	log.Info("Inserting fetched CVEs...")
	return r.insertIntoJvn(convertJvn(items))
}

// insertIntoJvn inserts or refreshes the JVN data of each CveDetail in cves.
//
// cves are processed in chunks of 10, each chunk in its own transaction.
// For every CVE:
//   - if no CveDetail with the same CveID exists, the whole CveDetail
//     (including its Jvn) is created;
//   - if the CveDetail exists but has no Jvn, the new Jvn is created and
//     linked to it;
//   - if a Jvn exists, it is replaced only when the incoming
//     LastModifiedDate is strictly newer. Its References, its Cpes and the
//     Jvn itself are hard-deleted (Unscoped), then the new Jvn is created.
//
// All reads go through the chunk's transaction. A CVE that appears twice in
// one chunk therefore sees the row written for its first occurrence.
// On error the current chunk is rolled back and the error is returned.
// Chunks committed earlier stay committed.
func (r *RDBDriver) insertIntoJvn(cves []models.CveDetail) error {
	var refreshedJvns []string
	bar := pb.New(len(cves))
	if c.Conf.Quiet {
		bar.Output = ioutil.Discard
	} else {
		bar.Output = os.Stderr
	}
	bar.Start()

	for chunked := range chunkSlice(cves, 10) {
		tx := r.conn.Begin()
		if err := tx.Error; err != nil {
			return fmt.Errorf("Failed to begin transaction. err: %s", err)
		}

		for _, cve := range chunked {
			bar.Increment()
			cve.Jvn.CveID = cve.CveID

			// Select the existing CveDetail, if any. An explicit column
			// condition is used because gorm drops zero-valued fields from
			// struct conditions, so an empty CveID would match any row.
			old := models.CveDetail{}
			result := tx.Where("cve_id = ?", cve.CveID).First(&old)
			if result.Error != nil && !result.RecordNotFound() {
				tx.Rollback()
				return fmt.Errorf("Failed to select old record. cve: %s, err: %s",
					cve.CveID, result.Error)
			}
			if result.RecordNotFound() || old.ID == 0 {
				if err := tx.Create(&cve).Error; err != nil {
					tx.Rollback()
					return fmt.Errorf("Failed to insert. cve: %s, err: %s",
						pp.Sprintf("%v", cve), err)
				}
				refreshedJvns = append(refreshedJvns, cve.CveID)
				continue
			}

			// Select the stored Jvn through tx so rows written earlier in this
			// chunk are visible. "Not found" just means there is no Jvn yet.
			jvnRec := models.Jvn{}
			related := tx.Model(&old).Related(&jvnRec, "Jvn")
			if related.Error != nil && !related.RecordNotFound() {
				tx.Rollback()
				return fmt.Errorf("Failed to select old record. cve: %s, err: %s",
					cve.CveID, related.Error)
			}

			if jvnRec.CveDetailID == 0 {
				cve.Jvn.CveDetailID = old.ID
				if err := tx.Create(&cve.Jvn).Error; err != nil {
					tx.Rollback()
					return fmt.Errorf("Failed to insert. cve: %s, err: %s",
						pp.Sprintf("%v", cve.Jvn), err)
				}
				refreshedJvns = append(refreshedJvns, cve.CveID)
				continue
			}

			// Refresh the JVN record only if the fetched one is strictly newer
			// (skip when stored >= fetched).
			if !cve.Jvn.LastModifiedDate.After(jvnRec.LastModifiedDate) {
				continue
			}
			log.Debugf("Newer record found. CveID: %s, old: %s, new: %s",
				cve.CveID, jvnRec.LastModifiedDate, cve.Jvn.LastModifiedDate)

			// Delete old References.
			refs := []models.Reference{}
			if err := tx.Model(&jvnRec).Related(&refs, "References").Error; err != nil {
				tx.Rollback()
				return errDelete(cve, err)
			}
			for i := range refs {
				if err := tx.Unscoped().Delete(&refs[i]).Error; err != nil {
					tx.Rollback()
					return errDelete(cve, err)
				}
			}

			// Delete old Cpes.
			cpes := []models.Cpe{}
			if err := tx.Model(&jvnRec).Related(&cpes, "Cpes").Error; err != nil {
				tx.Rollback()
				return errDelete(cve, err)
			}
			for i := range cpes {
				if err := tx.Unscoped().Delete(&cpes[i]).Error; err != nil {
					tx.Rollback()
					return errDelete(cve, err)
				}
			}

			// Delete old Jvn.
			if err := tx.Unscoped().Delete(&jvnRec).Error; err != nil {
				tx.Rollback()
				return errDelete(cve, err)
			}

			// Insert the new Jvn.
			cve.Jvn.CveDetailID = old.ID
			if err := tx.Create(&cve.Jvn).Error; err != nil {
				tx.Rollback()
				return fmt.Errorf("Failed to insert. cve: %s, err: %s",
					pp.Sprintf("%v", cve.Jvn), err)
			}
			refreshedJvns = append(refreshedJvns, cve.CveID)
		}

		if err := tx.Commit().Error; err != nil {
			return fmt.Errorf("Failed to commit. err: %s", err)
		}
	}
	bar.Finish()
	log.Infof("Refreshed %d Jvns.", len(refreshedJvns))
	log.Debugf("%v", refreshedJvns)
	return nil
}

// errDelete wraps err with a pretty-printed dump of the CVE whose old rows
// could not be deleted.
func errDelete(cve models.CveDetail, err error) error {
	dump := pp.Sprintf("%v", cve)
	return fmt.Errorf("Failed to delete old record. cve: %s, err: %s", dump, err)
}

// CountNvd count nvd table.
//
// It returns the number of rows in the Nvd table, or 0 and the query error.
func (r *RDBDriver) CountNvd() (int, error) {
	total := 0
	err := r.conn.Model(&models.Nvd{}).Count(&total).Error
	if err != nil {
		return 0, err
	}
	return total, nil
}

// InsertNvd inserts CveInformation into DB.
//
// The entries are converted to CveDetail records and upserted by
// insertIntoNvd, whose error (if any) is returned unchanged.
func (r *RDBDriver) InsertNvd(entries []nvd.Entry) error {
	log.Info("Inserting CVEs...")
	return r.insertIntoNvd(convertNvd(entries))
}

// insertIntoNvd inserts or refreshes the NVD data of each CveDetail in cves.
//
// cves are processed in chunks of 10, each chunk in its own transaction.
// For every CVE:
//   - if no CveDetail with the same CveID exists, the whole CveDetail
//     (including its Nvd) is created;
//   - if the CveDetail exists but has no Nvd, the new Nvd is created and
//     linked to it;
//   - if an Nvd exists, it is replaced only when the incoming
//     LastModifiedDate is strictly newer. Its References, its Cpes and the
//     Nvd itself are hard-deleted (Unscoped), then the new Nvd is created.
//
// All reads go through the chunk's transaction. A CVE that appears twice in
// one chunk therefore sees the row written for its first occurrence.
// On error the current chunk is rolled back and the error is returned.
// Chunks committed earlier stay committed.
func (r *RDBDriver) insertIntoNvd(cves []models.CveDetail) error {
	var refreshedNvds []string
	bar := pb.New(len(cves))
	if c.Conf.Quiet {
		bar.Output = ioutil.Discard
	} else {
		bar.Output = os.Stderr
	}
	bar.Start()

	for chunked := range chunkSlice(cves, 10) {
		tx := r.conn.Begin()
		if err := tx.Error; err != nil {
			return fmt.Errorf("Failed to begin transaction. err: %s", err)
		}

		for _, cve := range chunked {
			bar.Increment()
			cve.Nvd.CveID = cve.CveID

			// Select the existing CveDetail, if any. An explicit column
			// condition is used because gorm drops zero-valued fields from
			// struct conditions, so an empty CveID would match any row.
			old := models.CveDetail{}
			result := tx.Where("cve_id = ?", cve.CveID).First(&old)
			if result.Error != nil && !result.RecordNotFound() {
				tx.Rollback()
				return fmt.Errorf("Failed to select old record. cve: %s, err: %s",
					cve.CveID, result.Error)
			}
			if result.RecordNotFound() || old.ID == 0 {
				if err := tx.Create(&cve).Error; err != nil {
					tx.Rollback()
					return fmt.Errorf("Failed to insert. cve: %s, err: %s",
						pp.Sprintf("%v", cve), err)
				}
				refreshedNvds = append(refreshedNvds, cve.CveID)
				continue
			}

			// Select the stored Nvd through tx so rows written earlier in this
			// chunk are visible. "Not found" just means there is no Nvd yet.
			nvdRec := models.Nvd{}
			related := tx.Model(&old).Related(&nvdRec, "Nvd")
			if related.Error != nil && !related.RecordNotFound() {
				tx.Rollback()
				return fmt.Errorf("Failed to select old record. cve: %s, err: %s",
					cve.CveID, related.Error)
			}

			if nvdRec.CveDetailID == 0 {
				cve.Nvd.CveDetailID = old.ID
				if err := tx.Create(&cve.Nvd).Error; err != nil {
					tx.Rollback()
					return fmt.Errorf("Failed to insert. cve: %s, err: %s",
						pp.Sprintf("%v", cve.Nvd), err)
				}
				refreshedNvds = append(refreshedNvds, cve.CveID)
				continue
			}

			// Refresh the NVD record only if the fetched one is strictly newer
			// (skip when stored >= fetched).
			if !cve.Nvd.LastModifiedDate.After(nvdRec.LastModifiedDate) {
				continue
			}
			log.Debugf("newer Record found. CveID: %s, old: %s, new: %s",
				cve.CveID, nvdRec.LastModifiedDate, cve.Nvd.LastModifiedDate)

			// Delete old References.
			refs := []models.Reference{}
			if err := tx.Model(&nvdRec).Related(&refs, "References").Error; err != nil {
				tx.Rollback()
				return errDelete(cve, err)
			}
			for i := range refs {
				if err := tx.Unscoped().Delete(&refs[i]).Error; err != nil {
					tx.Rollback()
					return errDelete(cve, err)
				}
			}

			// Delete old Cpes.
			cpes := []models.Cpe{}
			if err := tx.Model(&nvdRec).Related(&cpes, "Cpes").Error; err != nil {
				tx.Rollback()
				return errDelete(cve, err)
			}
			for i := range cpes {
				if err := tx.Unscoped().Delete(&cpes[i]).Error; err != nil {
					tx.Rollback()
					return errDelete(cve, err)
				}
			}

			// Delete old Nvd.
			if err := tx.Unscoped().Delete(&nvdRec).Error; err != nil {
				tx.Rollback()
				return errDelete(cve, err)
			}

			// Insert the new Nvd.
			cve.Nvd.CveDetailID = old.ID
			if err := tx.Create(&cve.Nvd).Error; err != nil {
				tx.Rollback()
				return fmt.Errorf("Failed to insert. cve: %s, err: %s",
					pp.Sprintf("%v", cve.Nvd), err)
			}
			refreshedNvds = append(refreshedNvds, cve.CveID)
		}

		if err := tx.Commit().Error; err != nil {
			return fmt.Errorf("Failed to commit. err: %s", err)
		}
	}
	bar.Finish()

	log.Infof("Refreshed %d Nvds.", len(refreshedNvds))
	//  log.Debugf("%v", refreshedNvds)
	return nil
}
