// -*- Mode: Go; indent-tabs-mode: t -*-

/*
 * Copyright (C) 2023 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package seedwriter

import (
	"bufio"
	"bytes"
	"fmt"
	"io/ioutil"
	"os"
	"sort"
	"strconv"
	"strings"

	"github.com/snapcore/snapd/asserts"
	"github.com/snapcore/snapd/asserts/snapasserts"
	"github.com/snapcore/snapd/osutil"
	"github.com/snapcore/snapd/snap"
	"github.com/snapcore/snapd/strutil"
)

// ManifestSnapRevision is a single "<snap-name> <snap-revision>" entry of
// the seed manifest: a snap name paired with a revision of that snap.
type ManifestSnapRevision struct {
	// SnapName is the name of the snap the entry refers to.
	SnapName string

	// Revision is the snap revision recorded for SnapName.
	Revision snap.Revision
}

// String renders the entry in the "<snap-name> <snap-revision>" form used
// by the seed manifest.
func (r *ManifestSnapRevision) String() string {
	return r.SnapName + " " + r.Revision.String()
}

// ManifestValidationSet is a validation-set entry of the seed manifest.
// An entry is either pinned ("<account-id>/<name>=<sequence>") or unpinned
// ("<account-id>/<name> <sequence>"); in both cases Sequence holds the
// sequence that was used during the image build.
type ManifestValidationSet struct {
	// AccountID and Name together identify the validation set; Unique
	// joins them as "<account-id>/<name>".
	AccountID string
	Name      string

	// Sequence is the validation-set sequence used for the image build.
	Sequence int

	// Pinned reports whether the entry pins the set to Sequence.
	Pinned bool

	// Snaps lists the names of the snaps recorded as controlled by this
	// validation set.
	Snaps []string
}

// newManifestValidationSet builds the manifest entry describing the given
// validation-set assertion with the given pinning. Snaps is left empty for
// the caller to populate.
func newManifestValidationSet(vsa *asserts.ValidationSet, pinned bool) *ManifestValidationSet {
	return &ManifestValidationSet{
		AccountID: vsa.AccountID(),
		Name:      vsa.Name(),
		Sequence:  vsa.Sequence(),
		Pinned:    pinned,
	}
}

// String renders the validation set as it appears in the seed manifest:
// "<account-id>/<name>=<sequence>" when pinned, and
// "<account-id>/<name> <sequence>" otherwise.
func (s *ManifestValidationSet) String() string {
	sep := " "
	if s.Pinned {
		sep = "="
	}
	return fmt.Sprintf("%s/%s%s%d", s.AccountID, s.Name, sep, s.Sequence)
}

// Unique returns the "<account-id>/<name>" key that identifies the
// validation set regardless of its sequence or pinning.
func (s *ManifestValidationSet) Unique() string {
	return s.AccountID + "/" + s.Name
}

// hasSnap reports whether snapName is among the snaps recorded as
// controlled by this validation set.
func (s *ManifestValidationSet) hasSnap(snapName string) bool {
	return strutil.ListContains(s.Snaps, snapName)
}

// Manifest represents the validation-sets and snaps that are used to build
// an image seed. Each rule can only be added once, which is what allows a
// pre-provided manifest to be honoured.
// The seed.manifest generated by ubuntu-image contains entries in the
// following format:
//
//	<account-id>/<name>=<sequence>
//	<account-id>/<name> <sequence>
//	<snap-name> <snap-revision>
type Manifest struct {
	// revsAllowed holds the allowed revision per snap, keyed by snap name.
	revsAllowed map[string]*ManifestSnapRevision
	// revsSeeded records the revision actually seeded per snap, keyed by
	// snap name.
	revsSeeded map[string]*ManifestSnapRevision
	// vsAllowed holds the allowed validation sets, keyed by
	// "<account-id>/<name>".
	vsAllowed map[string]*ManifestValidationSet
	// vsSeeded records the validation sets actually seeded, keyed by
	// "<account-id>/<name>".
	vsSeeded map[string]*ManifestValidationSet
}

// NewManifest returns an empty Manifest with no allowed or seeded entries.
func NewManifest() *Manifest {
	sm := &Manifest{}
	sm.revsAllowed = make(map[string]*ManifestSnapRevision)
	sm.revsSeeded = make(map[string]*ManifestSnapRevision)
	sm.vsAllowed = make(map[string]*ManifestValidationSet)
	sm.vsSeeded = make(map[string]*ManifestValidationSet)
	return sm
}

// MockManifest is strictly for unit tests, do not use for non-test code.
// It starts from NewManifest and replaces each internal map with the
// corresponding argument whenever that argument is non-nil.
func MockManifest(revsAllowed, revsSeeded map[string]*ManifestSnapRevision, vsAllowed, vsSeeded map[string]*ManifestValidationSet) *Manifest {
	osutil.MustBeTestBinary("MockManifest can only be used in unit tests")

	m := NewManifest()
	// nil arguments keep the empty maps allocated by NewManifest.
	if vsSeeded != nil {
		m.vsSeeded = vsSeeded
	}
	if vsAllowed != nil {
		m.vsAllowed = vsAllowed
	}
	if revsSeeded != nil {
		m.revsSeeded = revsSeeded
	}
	if revsAllowed != nil {
		m.revsAllowed = revsAllowed
	}
	return m
}

// isControlledByValidationSet reports whether any seeded validation set
// lists snapName among the snaps it controls.
func (sm *Manifest) isControlledByValidationSet(snapName string) bool {
	found := false
	for _, seeded := range sm.vsSeeded {
		if found = seeded.hasSnap(snapName); found {
			break
		}
	}
	return found
}

// SetAllowedSnapRevision adds a revision rule for the given snap name, meaning
// that any snap marked as seeded through MarkSnapRevisionSeeded will be
// validated against this rule. The manifest only keeps one revision per snap:
// the first rule wins and any subsequent calls for the same snap are ignored.
// It returns an error if revision is unset (0).
func (sm *Manifest) SetAllowedSnapRevision(snapName string, revision snap.Revision) error {
	// Positive revisions come from the store and negative ones from locally
	// sourced snaps. Both are accepted in the seed.manifest as long as the
	// user can provide the matching snaps; only 0 (unset) is rejected.
	if revision.Unset() {
		return fmt.Errorf("snap revision for %q in manifest cannot be 0 (unset)", snapName)
	}

	if _, exists := sm.revsAllowed[snapName]; exists {
		return nil
	}
	sm.revsAllowed[snapName] = &ManifestSnapRevision{
		SnapName: snapName,
		Revision: revision,
	}
	return nil
}

// SetAllowedValidationSet adds a sequence rule for the given validation set,
// meaning that any validation set marked as seeded through
// MarkValidationSetSeeded must match the given sequence and pinning. The
// manifest only keeps one rule per validation set: the first rule wins and
// any subsequent calls for the same set are ignored. It returns an error if
// sequence is not positive.
func (sm *Manifest) SetAllowedValidationSet(accountID, name string, sequence int, pinned bool) error {
	key := accountID + "/" + name
	if sequence <= 0 {
		return fmt.Errorf("cannot add allowed validation set %q for a unknown sequence", key)
	}

	if _, exists := sm.vsAllowed[key]; exists {
		return nil
	}
	sm.vsAllowed[key] = &ManifestValidationSet{
		AccountID: accountID,
		Name:      name,
		Sequence:  sequence,
		Pinned:    pinned,
	}
	return nil
}

// MarkSnapRevisionSeeded attempts to mark a snap-revision as seeded in the
// manifest. The revision must match any allowed revision recorded for the
// snap, whether it was set directly through SetAllowedSnapRevision or derived
// from a previously seeded validation set. A snap can only be marked as
// seeded once.
func (sm *Manifest) MarkSnapRevisionSeeded(snapName string, revision snap.Revision) error {
	if allowed, ok := sm.revsAllowed[snapName]; ok && allowed.Revision != revision {
		return fmt.Errorf("snap %q (%s) does not match the allowed revision %s",
			snapName, revision, allowed.Revision)
	}

	if prev, ok := sm.revsSeeded[snapName]; ok {
		return fmt.Errorf("cannot mark %q (%s) as seeded, it has already been marked seeded for revision %s",
			snapName, revision, prev.Revision)
	}

	sm.revsSeeded[snapName] = &ManifestSnapRevision{SnapName: snapName, Revision: revision}
	return nil
}

// MarkValidationSetSeeded marks a validation-set as seeded. It verifies the
// set against any rule previously added through SetAllowedValidationSet. It
// then registers an allowed revision for every snap the set constrains to a
// specific revision.
// This relies on validation-set assertions being marked here the moment they
// are fetched by the seedwriter, which should be done before the first call
// to MarkSnapRevisionSeeded.
func (sm *Manifest) MarkValidationSetSeeded(vsa *asserts.ValidationSet, pinned bool) error {
	vs := newManifestValidationSet(vsa, pinned)
	key := vs.Unique()

	if _, seen := sm.vsSeeded[key]; seen {
		return fmt.Errorf("cannot mark validation set %q as seeded, it has already been marked as such",
			key)
	}

	// Enforce any pre-defined restriction for this validation set.
	if allowed, ok := sm.vsAllowed[key]; ok {
		switch {
		case allowed.Sequence != vs.Sequence:
			return fmt.Errorf("sequence of %q (%d) does not match the allowed sequence (%d)",
				key, vs.Sequence, allowed.Sequence)
		case allowed.Pinned != pinned:
			return fmt.Errorf("pinning of %q (%t) does not match the allowed pinning (%t)",
				key, pinned, allowed.Pinned)
		}
	}

	for _, sn := range vsa.Snaps() {
		// Only snaps whose presence is not "invalid" and that pin an
		// explicit revision constrain what may be seeded.
		if sn.Presence == asserts.PresenceInvalid || sn.Revision <= 0 {
			continue
		}

		name := sn.SnapName()
		if err := sm.SetAllowedSnapRevision(name, snap.R(sn.Revision)); err != nil {
			return err
		}
		// Book-keeping: remember which snaps this validation set controls.
		vs.Snaps = append(vs.Snaps, name)
	}

	sm.vsSeeded[key] = vs
	return nil
}

// AllowedSnapRevision returns the allowed revision recorded for snapName, or
// the zero (unset) snap.Revision when no rule exists for it.
func (sm *Manifest) AllowedSnapRevision(snapName string) snap.Revision {
	// TODO: Check seeded validation-sets as well.
	rule, ok := sm.revsAllowed[snapName]
	if !ok {
		return snap.Revision{}
	}
	return rule.Revision
}

// AllowedValidationSets returns the validation sets specified as allowed,
// ordered by their "<account-id>/<name>" key. It returns nil when none are set.
func (sm *Manifest) AllowedValidationSets() []*ManifestValidationSet {
	var result []*ManifestValidationSet
	for _, allowed := range sm.vsAllowed {
		result = append(result, allowed)
	}

	// Map iteration order is random; sort so callers (and tests) see a
	// stable order.
	sort.Slice(result, func(a, b int) bool {
		return result[a].Unique() < result[b].Unique()
	})
	return result
}

// parsePinnedValidationSet handles a "<account-id>/<name>=<sequence>" token
// and registers it in sm as a pinned allowed validation set.
func parsePinnedValidationSet(sm *Manifest, vs string) error {
	accountID, name, sequence, err := snapasserts.ParseValidationSet(vs)
	if err == nil {
		err = sm.SetAllowedValidationSet(accountID, name, sequence, true)
	}
	return err
}

// parseUnpinnedValidationSet handles a "<account-id>/<name> <sequence>" line,
// where vs is the first field and seqStr the second. It registers the entry in
// sm as an unpinned allowed validation set.
// The first field must not carry a "=<sequence>" of its own. Otherwise the
// line would name two sequences, and one of them would be silently dropped.
func parseUnpinnedValidationSet(sm *Manifest, vs, seqStr string) error {
	acc, name, _, err := snapasserts.ParseValidationSet(vs)
	if err != nil {
		return err
	}
	if strings.Contains(vs, "=") {
		return fmt.Errorf("cannot parse unpinned validation-set %q: unexpected pinned sequence", vs)
	}
	seq, err := strconv.Atoi(seqStr)
	if err != nil {
		return fmt.Errorf("invalid validation-set sequence: %q", seqStr)
	}
	return sm.SetAllowedValidationSet(acc, name, seq, false)
}

// parseSnapRevision handles a "<snap-name> <snap-revision>" line. It validates
// the snap name and parses the revision before registering the allowed
// revision in sm.
func parseSnapRevision(sm *Manifest, sn, revStr string) error {
	err := snap.ValidateName(sn)
	if err != nil {
		return err
	}

	rev, err := snap.ParseRevision(revStr)
	if err != nil {
		return err
	}
	return sm.SetAllowedSnapRevision(sn, rev)
}

// ReadManifest reads a seed.manifest previously generated by Manifest.Write
// and returns a new Manifest whose allowed snap revisions and validation sets
// reflect the contents.
// Lines starting with '#' and empty lines are ignored. Lines starting with a
// space or a tab are rejected. Every other line must be one of:
//
//	<account-id>/<name>=<sequence>   (pinned validation set)
//	<account-id>/<name> <sequence>   (unpinned validation set)
//	<snap-name> <snap-revision>      (snap revision)
func ReadManifest(manifestFile string) (*Manifest, error) {
	f, err := os.Open(manifestFile)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	sm := NewManifest()
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		line := scanner.Text()
		// Comments and blank lines carry no entries; hand-edited manifests
		// may contain either.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		// Indentation is not part of the format. Reject tabs as well as
		// spaces, since strings.Fields below would otherwise silently
		// accept a tab-indented line.
		if line[0] == ' ' || line[0] == '\t' {
			return nil, fmt.Errorf("line cannot start with any spaces: %q", line)
		}

		tokens := strings.Fields(line)

		switch {
		case len(tokens) == 1 && strings.Contains(tokens[0], "/"):
			// Pinned validation-set: <account-id>/<name>=<sequence>
			if err := parsePinnedValidationSet(sm, tokens[0]); err != nil {
				return nil, err
			}
		case len(tokens) == 2 && strings.Contains(tokens[0], "/"):
			// Unpinned validation-set: <account-id>/<name> <sequence>
			if err := parseUnpinnedValidationSet(sm, tokens[0], tokens[1]); err != nil {
				return nil, err
			}
		case len(tokens) == 2:
			// Snap revision: <snap> <revision>
			if err := parseSnapRevision(sm, tokens[0], tokens[1]); err != nil {
				return nil, err
			}
		default:
			return nil, fmt.Errorf("cannot parse line: %q", line)
		}
	}

	if err := scanner.Err(); err != nil {
		return nil, err
	}
	return sm, nil
}

// Write generates the seed.manifest contents from the seeded validation sets
// and snap revisions, and stores them in the given file path.
// Validation sets come first, followed by the snap revisions that are not
// controlled by a seeded validation set. Each group is sorted by key for
// deterministic output. Nothing is written when nothing has been seeded.
func (sm *Manifest) Write(filePath string) error {
	if len(sm.revsSeeded) == 0 && len(sm.vsSeeded) == 0 {
		return nil
	}

	vsKeys := make([]string, 0, len(sm.vsSeeded))
	for k := range sm.vsSeeded {
		vsKeys = append(vsKeys, k)
	}
	sort.Strings(vsKeys)

	// Get the keys for the seeded snap revision map and sort them by name
	// for consistent output. Snaps controlled by a seeded validation set are
	// filtered out; the validation-set entry already accounts for them.
	revisionKeys := make([]string, 0, len(sm.revsSeeded))
	for k := range sm.revsSeeded {
		if !sm.isControlledByValidationSet(k) {
			revisionKeys = append(revisionKeys, k)
		}
	}
	sort.Strings(revisionKeys)

	buf := bytes.NewBuffer(nil)
	for _, key := range vsKeys {
		fmt.Fprintf(buf, "%s\n", sm.vsSeeded[key])
	}
	for _, key := range revisionKeys {
		fmt.Fprintf(buf, "%s\n", sm.revsSeeded[key])
	}
	// seed.manifest is a plain data file, so there is no reason for it to be
	// executable.
	return ioutil.WriteFile(filePath, buf.Bytes(), 0644)
}
