# Copyright 2019 VMware, Inc.
# All rights reserved. -- VMware Confidential

"""Low-level block device interface.

Definitions:

   sector   Physical unit of storage, typically 512 bytes for 512n devices, or
            4096 bytes for 4Kn devices.

   block    A "chunk" of data, encompassing one or several consecutive sectors.

   lba      A Logical Block Address (LBA) specifies the linear address of a disk
            sector, with the first sector being at LBA 0, the second at LBA 1,
            etc.

   chs      Cylinder-Head-Sector (C/H/S) addressing is the legacy method to
            address physical disk sectors. CHS addressing is mostly obsolete
            since it was superseded in all modern storage devices by LBA
            addressing. A few firmware/OS standards (e.g. BIOS, MBR, ...) still
            rely on CHS addressing for backward compatibility purposes.
"""
from fcntl import ioctl
import os
from struct import unpack

class CHS(tuple):
   """Immutable (cylinder, head, sector) triple addressing one disk sector.

   Cylinders and heads are numbered from 0, sectors from 1 (as in the
   BIOS/ATA convention). Being a tuple, a CHS compares equal to the plain
   tuple (cylinder, head, sector).
   """

   def __new__(cls, cylinder, head, sector):
      # Sanity-check each component; sectors are 1-based, the rest 0-based.
      assert cylinder >= 0, "invalid CHS (cylinder=%d < 0)" % cylinder
      assert head >= 0, "invalid CHS (head=%d < 0)" % head
      assert sector >= 1, "invalid CHS (sector=%d < 1)" % sector
      return tuple.__new__(cls, (cylinder, head, sector))

   cylinder = property(lambda self: self[0], doc="Cylinder number (0-based).")
   head = property(lambda self: self[1], doc="Head number (0-based).")
   sector = property(lambda self: self[2], doc="Sector number (1-based).")

   def toLba(self, numHeads, sectorsPerTrack):
      """Return the LBA corresponding to this CHS address.

      Uses the BIOS/ATA conversion formula:
      LBA = (C * numHeads + H) * sectorsPerTrack + (S - 1)

      @param numHeads        number of heads of the disk geometry
      @param sectorsPerTrack number of sectors per track
      """
      cyl, hd, sec = self
      track = cyl * numHeads + hd
      return track * sectorsPerTrack + (sec - 1)

   @classmethod
   def fromLba(cls, lba, numHeads, sectorsPerTrack):
      """Build the CHS address of the given LBA.

      @param lba             logical block address to convert
      @param numHeads        number of heads of the disk geometry
      @param sectorsPerTrack number of sectors per track
      """
      # Split the LBA into (absolute track, sector index within the track),
      # then split the track into (cylinder, head).
      track, sectorIdx = divmod(lba, sectorsPerTrack)
      cyl, hd = divmod(track, numHeads)
      return cls(cyl, hd, sectorIdx + 1)


class BlockDev(object):
   """Class to read and write from/to a UNIX-style block device.

   Offsets and counts passed to the block I/O methods are expressed in
   sectors of ``sectorSize`` bytes. Call open() before any block I/O and
   close() afterwards.
   """

   def __init__(self, path, sectorSize, numSectors, numHeads=None,
                sectorsPerTrack=None, isRegularFile=False):
      """Create a handle for the device at path (left closed on return).

      @param path            path of the device node or image file
      @param sectorSize      sector size, in bytes
      @param numSectors      total number of sectors, or None to derive it
                             from the CHS geometry reported by the device
      @param numHeads        number of heads (overwritten by the geometry
                             query unless isRegularFile is True)
      @param sectorsPerTrack sectors per track (overwritten likewise)
      @param isRegularFile   True if path is a regular file: the geometry
                             ioctl is skipped and open() does not add
                             O_DIRECT | O_SYNC
      """
      self.path = path
      self.sectorSize = sectorSize
      self.numSectors = numSectors
      self.numHeads = numHeads
      self.sectorsPerTrack = sectorsPerTrack
      self._fd = None
      self._isRegularFile = isRegularFile

      if not self._isRegularFile:
         # Briefly open the device just to query its geometry.
         self.open(os.O_RDONLY)
         try:
            self._getGeometry()
         finally:
            self.close()

   @property
   def lastLba(self):
      """Last valid LBA.
      """
      return self.numSectors - 1

   def bytesToLba(self, byteOffset):
      """Convert a size/offset from bytes to LBA.

      @param byteOffset byte size/offset; must be a multiple of sectorSize
      """
      assert byteOffset % self.sectorSize == 0, \
         ("%u: invalid sector offset cannot be converted to LBA "
          "(not a multiple of sector size %u)" % (byteOffset, self.sectorSize))
      return byteOffset // self.sectorSize

   def lbaToBytes(self, lba):
      """Convert an LBA to a byte size/offset.
      """
      return lba * self.sectorSize

   def lbaToChs(self, lba):
      """Convert an LBA to CHS using this disk's heads/sectors-per-track.

      numHeads and sectorsPerTrack must be set (they are filled in from the
      device for non-regular files).
      """
      return CHS.fromLba(lba, self.numHeads, self.sectorsPerTrack)

   @property
   def sizeInMB(self):
      """Device size in MB (MiB, rounded down).
      """
      return self.numSectors * self.sectorSize // (1024 * 1024)

   def open(self, flags):
      """Open this disk for reading or writing.

      @param flags os.O_* open flags; O_DIRECT | O_SYNC are added for
                   non-regular files
      """
      if not self._isRegularFile:
         # use O_DIRECT | O_SYNC to force writes to disk.
         flags |= os.O_DIRECT | os.O_SYNC

      self._fd = os.open(self.path, flags)

   def close(self):
      """Close the file handle to this disk.

      Does nothing if the disk is not currently open.
      """
      # Forget the descriptor before closing it so the handle is marked
      # closed even if os.close() raises.
      fd, self._fd = self._fd, None
      if fd is not None:
         os.close(fd)

   def readBlock(self, offset, count):
      """Read disk blocks.

      @param offset starting sector
      @param count  number of sectors to read
      @return the bytes read. The result may be shorter than requested
              (e.g. at end of file), but it is asserted to be a whole
              number of sectors.
      """
      os.lseek(self._fd, offset * self.sectorSize, os.SEEK_SET)
      data = os.read(self._fd, count * self.sectorSize)
      nBytes = len(data)
      assert nBytes % self.sectorSize == 0, \
         ("%s: incomplete block read (%u not a multiple of sector size %u)" %
          (self.path, nBytes, self.sectorSize))
      return data

   def writeBlock(self, offset, buf):
      """Write disk blocks.

      @param offset starting sector
      @param buf    data to write; its length must be a multiple of the
                    sector size
      @return number of bytes written by a single os.write() call (short
              writes are not retried)
      """
      assert len(buf) % self.sectorSize == 0, \
         ("invalid write request (%s: input buffer size %u is not a multiple "
          "of sector size %u)" % (self.path, len(buf), self.sectorSize))
      os.lseek(self._fd, offset * self.sectorSize, os.SEEK_SET)
      nBytes = os.write(self._fd, buf)
      assert nBytes % self.sectorSize == 0, \
         ("%s: incomplete block write (%u not a multiple of sector size %u)" %
          (self.path, nBytes, self.sectorSize))
      return nBytes

   def eraseBlocks(self, offset, count):
      """Erase (zero-fill) disk blocks at given offset.

         @param offset specifies the starting sector to erase
         @param count  number of sectors to erase
      """
      # Write in small chunks to keep memory footprint sane
      chunkSectors = 32
      zeros = b'\x00' * self.sectorSize * chunkSectors
      while count > 0:
         n = min(count, chunkSectors)
         # Only the final, partial chunk needs a shorter buffer.
         chunk = zeros if n == chunkSectors else zeros[:n * self.sectorSize]
         self.writeBlock(offset, chunk)
         offset += n
         count -= n

   def _getGeometry(self):
      """Retrieve this disk's geometry via the HDIO_GETGEO ioctl.

      Sets numHeads and sectorsPerTrack. It sets numSectors only if that is
      still None. The disk must be open.
      """
      IOCTL_HDIO_GETGEO = 0x301
      SIZEOF_HDIO_GETGEO_RESULT = 8

      # NOTE(review): the result presumably mirrors Linux's struct
      # hd_geometry (u8 heads, u8 sectors, u16 cylinders, unsigned long
      # start), which is 16 bytes on 64-bit hosts. CPython's ioctl() stages
      # mutable buffers shorter than 1024 bytes in a 1024-byte scratch
      # buffer, so only the first 8 bytes are copied back. The trailing
      # field is discarded anyway. The "<" format assumes a little-endian
      # host -- confirm.
      geo = bytearray(SIZEOF_HDIO_GETGEO_RESULT)
      ioctl(self._fd, IOCTL_HDIO_GETGEO, geo)

      numHeads, sectorsPerTrack, numCylinders, _ = unpack("<BBHI", geo)
      if self.numSectors is None:
         # There are cases (such as with virtual disks) where the total number
         # of sectors reported by the OS is not a strict multiple of the CHS
         # geometry. We always trust the OS, and only resort to the following
         # calculation if the OS failed to set a valid total number of sectors.
         self.numSectors = numCylinders * numHeads * sectorsPerTrack

      self.numHeads = numHeads
      self.sectorsPerTrack = sectorsPerTrack


class Partition(object):
   """A disk partition covering the inclusive sector range [start, end].

   The constructor arguments are stored as the attributes num (from
   partNum), fsType, start, end, label, guid, bootable and uuid.
   """

   def __init__(self, partNum, fsType, start, end, label, guid=None,
                bootable=False, uuid=None):
      # Reject an inverted range before recording any attribute.
      if start > end:
         raise ValueError("invalid partition boundaries (start: %u > end: %u)" %
                          (start, end))

      self.num, self.fsType, self.label = partNum, fsType, label
      self.start, self.end = start, end
      self.guid, self.uuid = guid, uuid
      self.bootable = bootable

   @property
   def numSectors(self):
      """Number of sectors spanned by this partition (both ends included).
      """
      span = self.end - self.start
      return span + 1
