#
# chroot.py : APIs for moblin-chroot
#
# Copyright 2009, Intel Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.

import os
import os.path
import sys
import time
import optparse
import logging
import shutil
import subprocess
import string
import glob
import rpm
import re
import shlex

import mic.appcreate as appcreate
import mic.imgcreate as imgcreate
import mic.imgconvert as imgconvert
from mic.imgcreate.fs import *

def perror(msg):
    """Report *msg* on standard error, prefixed with ``Error:``."""
    text = "Error: %s" % msg
    sys.stderr.write(text + "\n")

def pwarning(msg):
    """Report *msg* on standard output, prefixed with ``Warning:``."""
    text = "Warning: %s" % msg
    print(text)

def pinfo(msg):
    """Report *msg* on standard output, prefixed with ``Info:``."""
    text = "Info: %s" % msg
    print(text)

def check_bind_mounts(chrootdir, bindmounts):
    """Validate a user-supplied bind-mount specification.

    *bindmounts* is a ``;``-separated list of ``src[:dst]`` entries.  A
    missing, empty or ``none`` destination means "same path as the source".
    Sources are expanded (``~``) and made absolute before being checked --
    the same way setup_chrootenv() resolves them -- so that e.g. ``~/src``
    is accepted and ``/proc/`` is recognized as a default mount.

    :param chrootdir: chroot root; when non-empty, a warning is printed for
        every destination that already exists inside it.
    :param bindmounts: the specification string; ``""``/``None`` means no
        extra mounts.
    :returns: ``False`` if any source is not an existing directory,
        ``True`` otherwise.
    """
    if not bindmounts:
        return True

    # Mount points that setup_chrootenv() always binds on its own.
    default_mounts = ("/proc", "/proc/sys/fs/binfmt_misc", "/", "/sys", "/dev",
                      "/dev/pts", "/dev/shm", "/var/lib/dbus", "/var/run/dbus")

    for mount in bindmounts.split(";"):
        if mount == "":
            continue
        srcdst = mount.split(":")
        if len(srcdst) == 1:
            srcdst.append("none")
        # Resolve '~' and relative paths exactly like setup_chrootenv();
        # the raw string would fail isdir() for "~/dir".
        src = os.path.abspath(os.path.expanduser(srcdst[0]))
        if not os.path.isdir(src):
            return False
        if src in default_mounts:
            continue
        if chrootdir:
            dst = srcdst[1]
            if dst in ("", "none"):
                dst = src
            dst = os.path.abspath(os.path.expanduser(dst))
            tmpdir = chrootdir + "/" + dst
            if os.path.isdir(tmpdir):
                pwarning("dir %s has existed." % tmpdir)

    return True

def cleanup_mountdir(chrootdir, bindmounts):
    """Remove the empty mount-point directories left by user bind mounts.

    For every ``src[:dst]`` entry of the ``;``-separated *bindmounts*
    string, ``chrootdir/dst`` (``dst`` defaults to ``src``) is removed when
    it is an empty directory; a warning is printed when it still has
    contents.

    Entries whose source or destination is one of the system mount points
    that setup_chrootenv() binds unconditionally (``/proc``, ``/sys``,
    ``/dev``, ..., and ``/``) are skipped.  Those directories belong to the
    chroot itself and are typically empty once unmounted, so removing them
    would damage the chroot.

    NOTE(review): a destination that already existed in the chroot before
    setup (and was therefore skipped by setup_chrootenv) is still removed
    here if it is empty; this function has no record of which directories
    it created.

    :param chrootdir: root directory of the chroot.
    :param bindmounts: the bind-mount specification; ``""``/``None`` is a
        no-op.
    """
    if not bindmounts:
        return

    protected = ("/proc", "/proc/sys/fs/binfmt_misc", "/", "/sys", "/dev",
                 "/dev/pts", "/dev/shm", "/var/lib/dbus", "/var/run/dbus")

    for mount in bindmounts.split(";"):
        if mount == "":
            continue
        srcdst = mount.split(":")
        if len(srcdst) == 1:
            srcdst.append("none")
        src = os.path.abspath(os.path.expanduser(srcdst[0]))
        dst = srcdst[1]
        if dst in ("", "none"):
            dst = srcdst[0]
        dst = os.path.abspath(os.path.expanduser(dst))
        if src in protected or dst in protected:
            continue
        tmpdir = chrootdir + "/" + dst
        if os.path.isdir(tmpdir):
            if not os.listdir(tmpdir):
                shutil.rmtree(tmpdir, ignore_errors = True)
            else:
                pwarning("dir %s isn't empty." % tmpdir)

# Lock file opened by setup_chrootenv() and closed by cleanup_chrootenv();
# -1 / "" mean "no chroot environment set up yet".
chroot_lockfd = -1
chroot_lock = ""
def setup_chrootenv(chrootdir, bindmounts = None):
    """Prepare *chrootdir* for use as a chroot and return its bind mounts.

    Bind-mounts the user directories from *bindmounts* (a ``;``-separated
    list of ``src[:dst]`` entries), then a fixed set of host system
    directories, the host root read-only at ``/parentroot`` and every
    ``/lib/modules/<kernel>`` read-only.  Afterwards the host's
    ``/etc/resolv.conf`` and ``/etc/mtab`` are copied into the chroot and
    ``chrootdir/.chroot.lock`` is opened and kept in the module globals
    ``chroot_lockfd`` / ``chroot_lock``.

    If any step fails, the mounts made so far are unmounted again before
    the exception propagates.  The caller never receives the mount list in
    that case, so it could not undo them itself.

    :param chrootdir: root directory of the chroot.
    :param bindmounts: user bind-mount specification, or ``None``/``""``.
    :returns: list of ``imgcreate.BindChrootMount`` objects to pass to
        cleanup_chrootenv().
    """
    global chroot_lockfd, chroot_lock

    # Host directories bound into every chroot, in mount order.  "/" is
    # handled separately (read-only at /parentroot).
    system_mounts = ["/proc", "/proc/sys/fs/binfmt_misc", "/sys", "/dev",
                     "/dev/pts", "/dev/shm", "/var/lib/dbus", "/var/run/dbus"]

    def get_bind_mounts(chrootdir, bindmounts):
        """Build the mount list: user mounts first, then the defaults."""
        chrootmounts = []
        for mount in (bindmounts or "").split(";"):
            if mount == "":
                continue
            srcdst = mount.split(":")
            srcdst[0] = os.path.abspath(os.path.expanduser(srcdst[0]))
            if len(srcdst) == 1:
                srcdst.append("none")
            if not os.path.isdir(srcdst[0]):
                continue
            if srcdst[0] == "/" or srcdst[0] in system_mounts:
                pwarning("%s will be mounted by default." % srcdst[0])
                continue
            if srcdst[1] == "" or srcdst[1] == "none":
                srcdst[1] = None
            else:
                srcdst[1] = os.path.abspath(os.path.expanduser(srcdst[1]))
                if os.path.isdir(chrootdir + "/" + srcdst[1]):
                    pwarning("%s has existed in %s , skip it." % (srcdst[1], chrootdir))
                    continue
            chrootmounts.append(imgcreate.BindChrootMount(srcdst[0], chrootdir, srcdst[1]))

        # Default bind mounts.
        for path in system_mounts:
            chrootmounts.append(imgcreate.BindChrootMount(path, chrootdir, None))
        chrootmounts.append(imgcreate.BindChrootMount("/", chrootdir, "/parentroot", "ro"))
        # /lib/modules can be absent (e.g. in containers); nothing to bind then.
        if os.path.isdir("/lib/modules"):
            for kernel in os.listdir("/lib/modules"):
                chrootmounts.append(imgcreate.BindChrootMount("/lib/modules/" + kernel, chrootdir, None, "ro"))

        return chrootmounts

    def unmount_all(mounts):
        """Best-effort unmount of *mounts*, newest first."""
        for b in reversed(mounts):
            try:
                b.unmount()
            except Exception as e:
                pwarning("failed to unmount %s: %s" % (b.dest, e))

    globalmounts = get_bind_mounts(chrootdir, bindmounts)
    mounted = []
    completed = False
    try:
        for b in globalmounts:
            print("bind_mount: %s -> %s" % (b.src, b.dest))
            b.mount()
            mounted.append(b)
        shutil.copyfile("/etc/resolv.conf", chrootdir + "/etc/resolv.conf")
        shutil.copyfile("/etc/mtab", chrootdir + "/etc/mtab")
        chroot_lock = os.path.join(chrootdir, ".chroot.lock")
        chroot_lockfd = open(chroot_lock, "w")
        completed = True
    finally:
        # try/finally (rather than except/raise) keeps the original
        # exception and traceback intact while rolling back.
        if not completed:
            unmount_all(mounted)
    return globalmounts

def cleanup_mounts(chrootdir):
    """Lazily unmount everything that is still mounted inside *chrootdir*.

    Two passes are made.  First the system mount points bound by
    setup_chrootenv() are unmounted with ``umount -l``, ignoring failures.
    Then every remaining ``/proc/mounts`` entry lying strictly below
    *chrootdir* is unmounted, newest first.

    :param chrootdir: root directory of the chroot.
    :returns: 0 on success, otherwise the exit status of the first
        ``umount`` of the second pass that failed.
    """
    checkpoints = ["/proc/sys/fs/binfmt_misc", "/proc", "/sys", "/dev/pts",
                   "/dev/shm", "/dev", "/var/lib/dbus", "/var/run/dbus"]

    def unescape(path):
        # /proc/mounts writes space, tab, newline and backslash inside
        # mount points as three-digit octal escapes (e.g. "\040").
        return re.sub(r"\\([0-7]{3})", lambda m: chr(int(m.group(1), 8)), path)

    # Mount points must lie below this prefix.  Matching whole path
    # components keeps /tmp/chroot from also claiming /tmp/chroot2.
    # realpath() because /proc/mounts lists resolved paths; for a chroot
    # of "/" the prefix "//" matches nothing, protecting the host.
    prefix = os.path.realpath(chrootdir) + "/"

    umountcmd = find_binary_path("umount")
    dev_null = os.open("/dev/null", os.O_WRONLY)
    try:
        for point in checkpoints:
            print(point)
            args = [umountcmd, "-l", chrootdir + point]
            subprocess.call(args, stdout=dev_null, stderr=dev_null)

        try:
            with open("/proc/mounts") as mounts_file:
                entries = mounts_file.read().splitlines()
        except (IOError, OSError):
            entries = []

        # Walk the table backwards.  Nested mounts are listed after their
        # parents, and a lazy unmount of a parent detaches its children
        # too, so unmounting the children afterwards would fail.
        for line in reversed(entries):
            fields = line.split()
            if len(fields) < 2:
                continue
            point = unescape(fields[1])
            if not point.startswith(prefix):
                continue
            print(point)
            args = [umountcmd, "-l", point]
            ret = subprocess.call(args, stdout=dev_null, stderr=dev_null)
            if ret != 0:
                print("ERROR: failed to unmount %s" % point)
                return ret
    finally:
        os.close(dev_null)
    return 0

def cleanup_chrootenv(chrootdir, bindmounts = None, globalmounts = None):
    """Tear down a chroot environment prepared by setup_chrootenv().

    Closes the chroot lock file (if one was opened) and unmounts
    *globalmounts* in reverse order.  Then, if imgcreate.my_fuser() reports
    the lock file as unused, it does the following:

    * removes an empty ``/parentroot``;
    * empties ``/etc/resolv.conf`` and deletes ``/etc/mtab`` in the chroot;
    * SIGKILLs every process whose root directory is *chrootdir*.

    Finally it removes the empty mount points of the user bind mounts.

    :param chrootdir: root directory of the chroot.
    :param bindmounts: the bind-mount specification given to
        setup_chrootenv().
    :param globalmounts: the mount list returned by setup_chrootenv().  It
        is not modified.
    """
    global chroot_lockfd, chroot_lock

    def bind_unmount(chrootmounts):
        # Newest first.  Iterate a reversed view so the caller's list keeps
        # its order; it is no longer reversed in place.
        for b in reversed(chrootmounts):
            print("bind_unmount: %s -> %s" % (b.src, b.dest))
            b.unmount()

    def cleanup_resolv(chrootdir):
        # Opening with "w" truncates the file to zero length.
        with open(chrootdir + "/etc/resolv.conf", "w"):
            pass

    def kill_processes(chrootdir):
        target = os.path.realpath(chrootdir)
        if target == "/":
            # Every process has "/" as its root; never kill the whole host.
            return
        for rootlink in glob.glob("/proc/*/root"):
            try:
                if os.readlink(rootlink) == target:
                    pid = int(rootlink.split("/")[2])
                    os.kill(pid, 9)  # 9 == SIGKILL
            except (OSError, ValueError):
                # Process already exited, is not inspectable by us, or the
                # entry is not a numeric pid.
                continue

    # chroot_lockfd is still -1 if setup_chrootenv() never opened the lock.
    if hasattr(chroot_lockfd, "close"):
        chroot_lockfd.close()
    if globalmounts:
        bind_unmount(globalmounts)
    # NOTE(review): my_fuser() presumably reports whether some process
    # still holds the lock file (i.e. another session uses this chroot);
    # confirm against imgcreate.  Shared state is torn down only if not.
    if not imgcreate.my_fuser(chroot_lock):
        tmpdir = chrootdir + "/parentroot"
        if os.path.isdir(tmpdir) and not os.listdir(tmpdir):
            shutil.rmtree(tmpdir, ignore_errors = True)
        cleanup_resolv(chrootdir)
        if os.path.exists(chrootdir + "/etc/mtab"):
            os.unlink(chrootdir + "/etc/mtab")
        kill_processes(chrootdir)
    cleanup_mountdir(chrootdir, bindmounts)

class ImageMount:
    """Mount an image file so that its root filesystem appears at a directory.

    Supported image types, as reported by imgcreate.get_image_type():

    * ``ext3fsimg`` -- a bare ext3 filesystem image, loop-mounted directly
      on the target directory;
    * ``raw`` -- a partitioned disk image; the partition table is read
      with ``parted`` and every partition is added to an
      appcreate.PartitionedMount rooted at the target directory;
    * ``livecd`` / ``liveusb`` -- the image is mounted on a temporary
      directory and its ``squashfs.img`` is unpacked to another temporary
      directory.  The ext3 image found inside is then loop-mounted on the
      target directory.

    cleanup() undoes mount() and is safe to call more than once.
    """

    def __init__(self, img):
        self.img = img
        # NOTE(review): get_file_size() appears to return MiB (it is scaled
        # to bytes here); confirm against imgcreate.
        self.imgsize = imgcreate.get_file_size(self.img) * 1024 * 1024
        self.imgmnt = imgcreate.mkdtemp()
        self.imgtype = imgcreate.get_image_type(img)
        self.extloop = None
        self.imgloop = None
        self.cleaned_up = False
        self.tmpoutdir = None

    def mount(self, targetdir):
        """Mount the image so that its root filesystem appears at *targetdir*.

        :raises imgcreate.CreatorError: if the image type is unknown, a
            partition's filesystem is not recognized, or a loopback mount
            fails.  Temporary directories are removed on failure.
        """
        if self.imgtype == "liveusb":
            self.disk = imgcreate.SparseLoopbackDisk(self.img, self.imgsize)
            self.imgloop = appcreate.PartitionedMount({'/dev/sdb': self.disk}, self.imgmnt, skipformat = True)
            self.img_fstype = "vfat"
            self.imgloop.add_partition(self.imgsize // 1024 // 1024, "/dev/sdb", "/", self.img_fstype, boot=False)
        elif self.imgtype == "livecd":
            self.imgloop = imgcreate.DiskMount(imgcreate.LoopbackDisk(self.img, 0), self.imgmnt)
        elif self.imgtype == "ext3fsimg":
            self._mount_ext3fsimg(targetdir)
            return
        elif self.imgtype == "raw":
            self._mount_raw(targetdir)
            return
        else:
            shutil.rmtree(self.imgmnt, ignore_errors = True)
            self.cleaned_up = True
            raise imgcreate.CreatorError("I can't recognize image type of %s" % self.img)

        self._mount_live(targetdir)

    def _mount_ext3fsimg(self, targetdir):
        """Loop-mount a bare ext3 filesystem image directly on *targetdir*."""
        self.extmnt = targetdir
        self.extloop = imgcreate.ExtDiskMount(imgcreate.SparseLoopbackDisk(self.img, self.imgsize),
                                              self.extmnt,
                                              "ext3",
                                              4096,
                                              "ext3 label")
        try:
            self.extloop.mount()
        except imgcreate.MountError as e:
            self.extloop.cleanup()
            shutil.rmtree(self.extmnt, ignore_errors = True)
            shutil.rmtree(self.imgmnt, ignore_errors = True)
            self.cleaned_up = True
            raise imgcreate.CreatorError("Failed to loopback mount '%s' : %s" %
                                         (self.img, e))
        self.os_image = self.img

    def _mount_raw(self, targetdir):
        """Mount every partition of a raw disk image below *targetdir*."""
        partedcmd = find_binary_path("parted")

        self.disk = imgcreate.SparseLoopbackDisk(self.img, self.imgsize)
        self.extmnt = targetdir
        self.imgloop = appcreate.PartitionedMount({'/dev/sdb': self.disk}, self.extmnt, skipformat = True)
        self.img_fstype = "ext3"
        self.extloop = None

        # Read the partition table of the raw disk, with sizes in bytes.
        p1 = subprocess.Popen([partedcmd, "-s", self.img, "unit", "B", "print"],
                              stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                              universal_newlines=True)
        out = p1.communicate()[0]

        root_mounted = False
        partition_mounts = 0

        for line in out.strip().split("\n"):
            line = line.strip()
            # Lines that start with a number are the partitions.  parted's
            # output can be translated, so no text line can be matched.
            if not line or not line[0].isdigit():
                continue

            # Some fields carry a trailing "," as list separator.
            line = line.replace(",", "")

            # Example of parted output lines that are handled:
            # Number  Start        End          Size         Type     File system     Flags
            #  1      512B         3400000511B  3400000000B  primary
            #  2      3400531968B  3656384511B  255852544B   primary  linux-swap(v1)
            #  3      3656384512B  3720347647B  63963136B    primary  fat16           boot, lba
            partition_info = re.split(r"\s+", line)

            size = partition_info[3].split("B")[0]

            if len(partition_info) < 6:
                # No filesystem can be found from partition line. Assuming
                # btrfs, because that is the only MeeGo fs that parted does
                # not recognize properly.
                # TODO: Can we make better assumption?
                fstype = "btrfs"
            elif partition_info[5] in ["ext2", "ext3", "ext4", "btrfs"]:
                fstype = partition_info[5]
            elif partition_info[5] in ["fat16", "fat32"]:
                fstype = "vfat"
            elif "swap" in partition_info[5]:
                fstype = "swap"
            else:
                raise imgcreate.CreatorError("Could not recognize partition fs type '%s'." % partition_info[5])

            if not root_mounted and fstype in ["ext2", "ext3", "ext4", "btrfs"]:
                # TODO: Check that this is actually the valid root partition from /etc/fstab
                mountpoint = "/"
                root_mounted = True
            elif fstype == "swap":
                mountpoint = "swap"
            else:
                # TODO: Assign better mount points for the rest of the partitions.
                partition_mounts += 1
                mountpoint = "/media/partition_%d" % partition_mounts

            boot = "boot" in partition_info

            print("Size: %s Bytes, fstype: %s, mountpoint: %s, boot: %s" % (size, fstype, mountpoint, boot))
            # TODO: add_partition should take bytes as size parameter.
            self.imgloop.add_partition(int(size) // 1024 // 1024, "/dev/sdb", mountpoint, fstype = fstype, boot = boot)

        try:
            self.imgloop.mount()
        except imgcreate.MountError as e:
            self.imgloop.cleanup()
            shutil.rmtree(self.imgmnt, ignore_errors = True)
            self.cleaned_up = True
            raise imgcreate.CreatorError("Failed to loopback mount '%s' : %s" %
                                         (self.img, e))
        self.os_image = self.img

    def _mount_live(self, targetdir):
        """Mount a live image, unpack its squashfs, loop-mount the ext3 image inside."""
        try:
            self.imgloop.mount()
        except imgcreate.MountError as e:
            self.imgloop.cleanup()
            shutil.rmtree(self.imgmnt, ignore_errors = True)
            self.cleaned_up = True
            raise imgcreate.CreatorError("Failed to loopback mount '%s' : %s" %
                                         (self.img, e))

        # legacy LiveOS filesystem layout support, remove for F9 or F10
        if os.path.exists(self.imgmnt + "/squashfs.img"):
            squashimg = self.imgmnt + "/squashfs.img"
        else:
            squashimg = self.imgmnt + "/LiveOS/squashfs.img"

        self.tmpoutdir = imgcreate.mkdtemp()
        # unsquashfs requires that the output directory does not exist yet
        shutil.rmtree(self.tmpoutdir, ignore_errors = True)
        imgcreate.uncompress_squashfs(squashimg, self.tmpoutdir)

        # legacy LiveOS filesystem layout support, remove for F9 or F10
        if os.path.exists(self.tmpoutdir + "/os.img"):
            self.os_image = self.tmpoutdir + "/os.img"
        else:
            self.os_image = self.tmpoutdir + "/LiveOS/ext3fs.img"

        if not os.path.exists(self.os_image):
            self.imgloop.cleanup()
            shutil.rmtree(self.tmpoutdir, ignore_errors = True)
            shutil.rmtree(self.imgmnt, ignore_errors = True)
            self.cleaned_up = True
            raise imgcreate.CreatorError("'%s' is not a valid live CD ISO : neither "
                                         "LiveOS/ext3fs.img nor os.img exist" %
                                         self.img)

        # Loop-mount the unpacked filesystem image on the target dir.
        imgsize = imgcreate.get_file_size(self.os_image) * 1024 * 1024
        self.extmnt = targetdir
        self.extloop = imgcreate.ExtDiskMount(imgcreate.SparseLoopbackDisk(self.os_image, imgsize),
                                              self.extmnt,
                                              "ext3",
                                              4096,
                                              "ext3 label")
        try:
            self.extloop.mount()
        except imgcreate.MountError as e:
            self.extloop.cleanup()
            shutil.rmtree(self.extmnt, ignore_errors = True)
            self.imgloop.cleanup()
            shutil.rmtree(self.tmpoutdir, ignore_errors = True)
            shutil.rmtree(self.imgmnt, ignore_errors = True)
            self.cleaned_up = True
            raise imgcreate.CreatorError("Failed to loopback mount '%s' : %s" %
                                         (self.os_image, e))

    def cleanup(self):
        """Unmount everything mount() set up and remove temporary directories.

        Calling it again after a successful cleanup (or after mount() has
        already cleaned up on failure) does nothing.
        """
        if self.cleaned_up:
            return
        if self.extloop:
            self.extloop.cleanup()
        if self.imgloop:
            self.imgloop.cleanup()
        if self.tmpoutdir and os.path.isdir(self.tmpoutdir):
            shutil.rmtree(self.tmpoutdir, ignore_errors = True)
        shutil.rmtree(self.imgmnt, ignore_errors = True)
        self.cleaned_up = True

def chroot(chrootdir, bindmounts = None, execute = "/bin/bash"):
    """Run *execute* inside *chrootdir* with the chroot environment set up.

    The target architecture is detected by running ``file -b`` on
    ``/bin/bash`` or ``/sbin/init`` inside the chroot.  For ARM targets a
    qemu emulator is registered via imgcreate.setup_qemu_emulator() and its
    binary is removed from the chroot afterwards.  Intel (x86 / x86-64)
    targets run natively.

    The bind mounts are created with setup_chrootenv() and torn down with
    cleanup_chrootenv() once the command exits.

    :param chrootdir: root directory of the chroot.
    :param bindmounts: user bind-mount specification passed to
        setup_chrootenv().
    :param execute: command line, split with shlex, run as a child process
        chroot()ed into *chrootdir*.
    :raises imgcreate.CreatorError: if no architecture can be detected or
        the command cannot be started.
    """
    def mychroot():
        os.chroot(chrootdir)
        os.chdir("/")

    files_to_check = ["/bin/bash", "/sbin/init"]
    architecture_found = False

    # Register statically-linked qemu-arm if it is an ARM fs.
    qemu_emulator = None

    dev_null = os.open("/dev/null", os.O_WRONLY)
    try:
        for ftc in files_to_check:
            ftc = "%s/%s" % (chrootdir, ftc)

            # Return code of 'file' is "almost always" 0 based on some man
            # pages, so we need to check the file existence first.
            if not os.path.exists(ftc):
                continue

            filecmd = find_binary_path("file")
            # -b omits the file name from the output, so a chroot path that
            # happens to contain "ARM" or "Intel" cannot be mistaken for the
            # binary's architecture.
            initp1 = subprocess.Popen([filecmd, "-b", ftc], stdout=subprocess.PIPE,
                                      stderr=dev_null, universal_newlines=True)
            for desc in initp1.communicate()[0].strip().split("\n"):
                if "ARM" in desc:
                    qemu_emulator = imgcreate.setup_qemu_emulator(chrootdir, "arm")
                    architecture_found = True
                    break
                # 32-bit x86 ELF reports "Intel 80386"; 64-bit reports "x86-64".
                if "Intel" in desc or "x86-64" in desc:
                    architecture_found = True
                    break

            if architecture_found:
                break
    finally:
        os.close(dev_null)

    if not architecture_found:
        raise imgcreate.CreatorError("Failed to get architecture from any of the following files %s from chroot." % files_to_check)

    # Stays None if setup_chrootenv() raises.  There is then no mount list
    # or lock file to clean up, and calling cleanup_chrootenv() would only
    # mask the real error.
    globalmounts = None
    try:
        print("Launching shell. Exit to continue.")
        print("----------------------------------")
        globalmounts = setup_chrootenv(chrootdir, bindmounts)
        args = shlex.split(execute)
        subprocess.call(args, preexec_fn = mychroot)
    except OSError as e:
        raise imgcreate.CreatorError("Failed to chroot: %s" % (e.strerror or e))
    finally:
        if globalmounts is not None:
            cleanup_chrootenv(chrootdir, bindmounts, globalmounts)
        if qemu_emulator:
            os.unlink(chrootdir + qemu_emulator)
