Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 183e56e8 authored by Tao Bao's avatar Tao Bao
Browse files

releasetools: Reduce memory footprint for BBOTA generation.

The major issue with the existing implementation is unnecessarily
holding too much data in memory, such as HashBlocks() which first reads
in *all* the data to a list before hashing. We can leverage generator
functions to stream such operations.

This CL makes the following changes to reduce the peak memory use.
 - Adding RangeSha1() and WriteRangeDataToFd() to Image classes. These
   functions perform the operations on-the-fly.
 - Caching the computed SHA-1 values for a Transfer instance.

As a result, this CL reduces the peak memory use by ~80% (e.g. reducing
from 5.85GB to 1.16GB for the same incremental, as shown by "Maximum
resident set size" from `/usr/bin/time -v`). It also effectively
improves the (package generation) performance by ~30%.

Bug: 35768998
Bug: 32312123
Test: Generating the same incremental w/ and w/o the CL give identical
      output packages.
Change-Id: Ia5c6314b41da73dd6fe1dbe2ca81bbd89b517cec
parent 8e022843
Loading
Loading
Loading
Loading
+148 −109
Original line number Diff line number Diff line
@@ -24,8 +24,8 @@ import os
import os.path
import re
import subprocess
import sys
import threading
import tempfile

from collections import deque, OrderedDict
from hashlib import sha1
@@ -35,69 +35,67 @@ from rangelib import RangeSet
__all__ = ["EmptyImage", "DataImage", "BlockImageDiff"]


def compute_patch(src, tgt, imgdiff=False):
  srcfd, srcfile = tempfile.mkstemp(prefix="src-")
  tgtfd, tgtfile = tempfile.mkstemp(prefix="tgt-")
  patchfd, patchfile = tempfile.mkstemp(prefix="patch-")
  os.close(patchfd)
def compute_patch(srcfile, tgtfile, imgdiff=False):
  patchfile = common.MakeTempFile(prefix="patch-")

  try:
    with os.fdopen(srcfd, "wb") as f_src:
      for p in src:
        f_src.write(p)

    with os.fdopen(tgtfd, "wb") as f_tgt:
      for p in tgt:
        f_tgt.write(p)
    try:
      os.unlink(patchfile)
    except OSError:
      pass
  if imgdiff:
      p = subprocess.call(["imgdiff", "-z", srcfile, tgtfile, patchfile],
                          stdout=open("/dev/null", "a"),
    p = subprocess.call(
        ["imgdiff", "-z", srcfile, tgtfile, patchfile],
        stdout=open(os.devnull, 'w'),
        stderr=subprocess.STDOUT)
  else:
      p = subprocess.call(["bsdiff", srcfile, tgtfile, patchfile])
    p = subprocess.call(
        ["bsdiff", srcfile, tgtfile, patchfile],
        stdout=open(os.devnull, 'w'),
        stderr=subprocess.STDOUT)

  if p:
    raise ValueError("diff failed: " + str(p))

  with open(patchfile, "rb") as f:
    return f.read()
  finally:
    try:
      os.unlink(srcfile)
      os.unlink(tgtfile)
      os.unlink(patchfile)
    except OSError:
      pass


class Image(object):
  """Abstract interface for image sources consumed by BlockImageDiff.

  Concrete implementations visible in this commit are EmptyImage,
  DataImage and (in sparse_img) SparseImage. Per the commit message,
  RangeSha1() and WriteRangeDataToFd() were added so callers can hash or
  write block ranges without materializing all the data first.
  """

  def RangeSha1(self, ranges):
    # Return the hex SHA-1 digest of the data covered by 'ranges'.
    raise NotImplementedError

  def ReadRangeSet(self, ranges):
    # Return the data in 'ranges', broken up into chunks in any
    # implementation-convenient way.
    raise NotImplementedError

  def TotalSha1(self, include_clobbered_blocks=False):
    # Return the hex SHA-1 of the whole image (care_map minus
    # clobbered_blocks unless include_clobbered_blocks is True).
    raise NotImplementedError

  def WriteRangeDataToFd(self, ranges, fd):
    # Write the data covered by 'ranges' directly to file object 'fd'.
    raise NotImplementedError


class EmptyImage(Image):
  """A zero-length image."""

  # NOTE(review): this capture shows BOTH the class attributes below AND an
  # __init__ assigning identical instance attributes. That is old/new diff
  # residue from the commit page rendering (the commit replaced one form
  # with the other) — confirm against the real blockimgdiff.py. As written
  # it is still valid Python and the values agree, so behavior matches.
  blocksize = 4096
  care_map = RangeSet()
  clobbered_blocks = RangeSet()
  extended = RangeSet()
  total_blocks = 0
  file_map = {}

  def __init__(self):
    self.blocksize = 4096
    self.care_map = RangeSet()
    self.clobbered_blocks = RangeSet()
    self.extended = RangeSet()
    self.total_blocks = 0
    self.file_map = {}

  def RangeSha1(self, ranges):
    # SHA-1 of zero bytes, regardless of 'ranges' — the image has no data.
    return sha1().hexdigest()

  def ReadRangeSet(self, ranges):
    # No data to return.
    return ()

  def TotalSha1(self, include_clobbered_blocks=False):
    # EmptyImage always carries empty clobbered_blocks, so
    # include_clobbered_blocks can be ignored.
    assert self.clobbered_blocks.size() == 0
    return sha1().hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    # Writing from an empty image is a programming error, not a no-op.
    raise ValueError("Can't write data from EmptyImage to file")


class DataImage(Image):
  """An image wrapped around a single string of data."""
@@ -160,23 +158,39 @@ class DataImage(Image):
    if clobbered_blocks:
      self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

  def _GetRangeData(self, ranges):
    """Generator yielding one bytes chunk per (start, end) block range.

    Slices self.data by block index, so nothing beyond one range's worth
    of data is held at a time — this is the streaming primitive that
    RangeSha1() and WriteRangeDataToFd() build on.
    """
    for s, e in ranges:
      yield self.data[s*self.blocksize:e*self.blocksize]

  def RangeSha1(self, ranges):
    """Return the hex SHA-1 of the data in 'ranges', hashed chunk by chunk."""
    h = sha1()
    for data in self._GetRangeData(ranges):
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    # NOTE(review): two stacked return statements — old/new diff residue in
    # this commit-page capture; only the first (pre-commit list-of-slices
    # version) is reachable as written. The commit's new version is the
    # second line. Confirm against the real file before relying on either.
    return [self.data[s*self.blocksize:e*self.blocksize] for (s, e) in ranges]
    return [self._GetRangeData(ranges)]

  def TotalSha1(self, include_clobbered_blocks=False):
    if not include_clobbered_blocks:
      # NOTE(review): the two return statements below are old/new diff
      # residue; only the first is reachable as captured. It also passes a
      # list to sha1(), which would raise TypeError at runtime — presumably
      # a capture artifact; the commit's new version (second return)
      # delegates to RangeSha1(). Verify against the real file.
      ranges = self.care_map.subtract(self.clobbered_blocks)
      return sha1(self.ReadRangeSet(ranges)).hexdigest()
      return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
    else:
      # Hash the raw backing buffer, clobbered blocks included.
      return sha1(self.data).hexdigest()

  def WriteRangeDataToFd(self, ranges, fd):
    """Stream the data in 'ranges' to file object 'fd' chunk by chunk."""
    for data in self._GetRangeData(ranges):
      fd.write(data)


class Transfer(object):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, style, by_id):
  def __init__(self, tgt_name, src_name, tgt_ranges, src_ranges, tgt_sha1,
               src_sha1, style, by_id):
    self.tgt_name = tgt_name
    self.src_name = src_name
    self.tgt_ranges = tgt_ranges
    self.src_ranges = src_ranges
    self.tgt_sha1 = tgt_sha1
    self.src_sha1 = src_sha1
    self.style = style
    self.intact = (getattr(tgt_ranges, "monotonic", False) and
                   getattr(src_ranges, "monotonic", False))
@@ -251,6 +265,9 @@ class HeapItem(object):
#      Implementations are free to break up the data into list/tuple
#      elements in any way that is convenient.
#
#    RangeSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the specified range.
#
#    TotalSha1(): a function that returns (as a hex string) the SHA-1
#      hash of all the data in the image (ie, all the blocks in the
#      care_map minus clobbered_blocks, or including the clobbered
@@ -332,15 +349,6 @@ class BlockImageDiff(object):
    self.ComputePatches(prefix)
    self.WriteTransfers(prefix)

  def HashBlocks(self, source, ranges): # pylint: disable=no-self-use
    """Return the hex SHA-1 of the data 'source' yields for 'ranges'.

    NOTE(review): this is the method the viewed commit removes — it first
    materializes ALL the range data via ReadRangeSet() before hashing,
    which is the peak-memory problem the commit message describes; callers
    are migrated to source.RangeSha1(ranges) instead.
    """
    data = source.ReadRangeSet(ranges)
    ctx = sha1()

    for p in data:
      ctx.update(p)

    return ctx.hexdigest()

  def WriteTransfers(self, prefix):
    def WriteSplitTransfers(out, style, target_blocks):
      """Limit the size of operand in command 'new' and 'zero' to 1024 blocks.
@@ -397,7 +405,7 @@ class BlockImageDiff(object):
          stashed_blocks += sr.size()
          out.append("stash %d %s\n" % (sid, sr.to_string_raw()))
        else:
          sh = self.HashBlocks(self.src, sr)
          sh = self.src.RangeSha1(sr)
          if sh in stashes:
            stashes[sh] += 1
          else:
@@ -429,7 +437,7 @@ class BlockImageDiff(object):
        mapped_stashes = []
        for stash_raw_id, sr in xf.use_stash:
          unstashed_src_ranges = unstashed_src_ranges.subtract(sr)
          sh = self.HashBlocks(self.src, sr)
          sh = self.src.RangeSha1(sr)
          sr = xf.src_ranges.map_within(sr)
          mapped_stashes.append(sr)
          if self.version == 2:
@@ -515,7 +523,7 @@ class BlockImageDiff(object):

            out.append("%s %s %s %s\n" % (
                xf.style,
                self.HashBlocks(self.tgt, xf.tgt_ranges),
                xf.tgt_sha1,
                xf.tgt_ranges.to_string_raw(), src_str))
          total += tgt_size
      elif xf.style in ("bsdiff", "imgdiff"):
@@ -542,8 +550,8 @@ class BlockImageDiff(object):
          out.append("%s %d %d %s %s %s %s\n" % (
              xf.style,
              xf.patch_start, xf.patch_len,
              self.HashBlocks(self.src, xf.src_ranges),
              self.HashBlocks(self.tgt, xf.tgt_ranges),
              xf.src_sha1,
              xf.tgt_sha1,
              xf.tgt_ranges.to_string_raw(), src_str))
        total += tgt_size
      elif xf.style == "zero":
@@ -574,8 +582,7 @@ class BlockImageDiff(object):
                   stash_threshold)

    if self.version >= 3:
      self.touched_src_sha1 = self.HashBlocks(
          self.src, self.touched_src_ranges)
      self.touched_src_sha1 = self.src.RangeSha1(self.touched_src_ranges)

    # Zero out extended blocks as a workaround for bug 20881595.
    if self.tgt.extended:
@@ -674,7 +681,7 @@ class BlockImageDiff(object):
        if self.version == 2:
          stashed_blocks_after += sr.size()
        else:
          sh = self.HashBlocks(self.src, sr)
          sh = self.src.RangeSha1(sr)
          if sh not in stashes:
            stashed_blocks_after += sr.size()

@@ -731,7 +738,7 @@ class BlockImageDiff(object):
          stashed_blocks -= sr.size()
          heapq.heappush(free_stash_ids, sid)
        else:
          sh = self.HashBlocks(self.src, sr)
          sh = self.src.RangeSha1(sr)
          assert sh in stashes
          stashes[sh] -= 1
          if stashes[sh] == 0:
@@ -745,10 +752,10 @@ class BlockImageDiff(object):

  def ComputePatches(self, prefix):
    print("Reticulating splines...")
    diff_q = []
    diff_queue = []
    patch_num = 0
    with open(prefix + ".new.dat", "wb") as new_f:
      for xf in self.transfers:
      for index, xf in enumerate(self.transfers):
        if xf.style == "zero":
          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
          print("%10d %10d (%6.2f%%) %7s %s %s" % (
@@ -756,17 +763,13 @@ class BlockImageDiff(object):
              str(xf.tgt_ranges)))

        elif xf.style == "new":
          for piece in self.tgt.ReadRangeSet(xf.tgt_ranges):
            new_f.write(piece)
          self.tgt.WriteRangeDataToFd(xf.tgt_ranges, new_f)
          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
          print("%10d %10d (%6.2f%%) %7s %s %s" % (
              tgt_size, tgt_size, 100.0, xf.style,
              xf.tgt_name, str(xf.tgt_ranges)))

        elif xf.style == "diff":
          src = self.src.ReadRangeSet(xf.src_ranges)
          tgt = self.tgt.ReadRangeSet(xf.tgt_ranges)

          # We can't compare src and tgt directly because they may have
          # the same content but be broken up into blocks differently, eg:
          #
@@ -775,20 +778,11 @@ class BlockImageDiff(object):
          # We want those to compare equal, ideally without having to
          # actually concatenate the strings (these may be tens of
          # megabytes).

          src_sha1 = sha1()
          for p in src:
            src_sha1.update(p)
          tgt_sha1 = sha1()
          tgt_size = 0
          for p in tgt:
            tgt_sha1.update(p)
            tgt_size += len(p)

          if src_sha1.digest() == tgt_sha1.digest():
          if xf.src_sha1 == xf.tgt_sha1:
            # These are identical; we don't need to generate a patch,
            # just issue copy commands on the device.
            xf.style = "move"
            tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
            if xf.src_ranges != xf.tgt_ranges:
              print("%10d %10d (%6.2f%%) %7s %s %s (from %s)" % (
                  tgt_size, tgt_size, 100.0, xf.style,
@@ -815,38 +809,64 @@ class BlockImageDiff(object):
                       xf.tgt_name.split(".")[-1].lower()
                       in ("apk", "jar", "zip"))
            xf.style = "imgdiff" if imgdiff else "bsdiff"
            diff_q.append((tgt_size, src, tgt, xf, patch_num))
            diff_queue.append((index, imgdiff, patch_num))
            patch_num += 1

        else:
          assert False, "unknown style " + xf.style

    if diff_q:
    if diff_queue:
      if self.threads > 1:
        print("Computing patches (using %d threads)..." % (self.threads,))
      else:
        print("Computing patches...")
      diff_q.sort()

      patches = [None] * patch_num
      diff_total = len(diff_queue)
      patches = [None] * diff_total

      # TODO: Rewrite with multiprocessing.ThreadPool?
      # Using multiprocessing doesn't give additional benefits, due to the
      # pattern of the code. The diffing work is done by subprocess.call, which
      # already runs in a separate process (not affected much by the GIL -
      # Global Interpreter Lock). Using multiprocess also requires either a)
      # writing the diff input files in the main process before forking, or b)
      # reopening the image file (SparseImage) in the worker processes. Doing
      # neither of them further improves the performance.
      lock = threading.Lock()
      def diff_worker():
        while True:
          with lock:
            if not diff_q:
            if not diff_queue:
              return
            tgt_size, src, tgt, xf, patchnum = diff_q.pop()
          patch = compute_patch(src, tgt, imgdiff=(xf.style == "imgdiff"))
          size = len(patch)
            xf_index, imgdiff, patch_index = diff_queue.pop()

          xf = self.transfers[xf_index]
          src_ranges = xf.src_ranges
          tgt_ranges = xf.tgt_ranges

          # Needs lock since WriteRangeDataToFd() is stateful (calling seek).
          with lock:
            patches[patchnum] = (patch, xf)
            print("%10d %10d (%6.2f%%) %7s %s %s %s" % (
                size, tgt_size, size * 100.0 / tgt_size, xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")"),
                str(xf.tgt_ranges), str(xf.src_ranges)))
            src_file = common.MakeTempFile(prefix="src-")
            with open(src_file, "wb") as fd:
              self.src.WriteRangeDataToFd(src_ranges, fd)

            tgt_file = common.MakeTempFile(prefix="tgt-")
            with open(tgt_file, "wb") as fd:
              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)

          try:
            patch = compute_patch(src_file, tgt_file, imgdiff)
          except ValueError as e:
            raise ValueError(
                "Failed to generate diff for %s: src=%s, tgt=%s: %s" % (
                    xf.tgt_name, xf.src_ranges, xf.tgt_ranges, e.message))

          with lock:
            patches[patch_index] = (xf_index, patch)
            if sys.stdout.isatty():
              progress = len(patches) * 100 / diff_total
              # '\033[K' is to clear to EOL.
              print(' [%d%%] %s\033[K' % (progress, xf.tgt_name), end='\r')
              sys.stdout.flush()

      threads = [threading.Thread(target=diff_worker)
                 for _ in range(self.threads)]
@@ -854,16 +874,29 @@ class BlockImageDiff(object):
        th.start()
      while threads:
        threads.pop().join()

      if sys.stdout.isatty():
        print('\n')
    else:
      patches = []

    p = 0
    with open(prefix + ".patch.dat", "wb") as patch_f:
      for patch, xf in patches:
        xf.patch_start = p
    offset = 0
    with open(prefix + ".patch.dat", "wb") as patch_fd:
      for index, patch in patches:
        xf = self.transfers[index]
        xf.patch_len = len(patch)
        patch_f.write(patch)
        p += len(patch)
        xf.patch_start = offset
        offset += xf.patch_len
        patch_fd.write(patch)

        if common.OPTIONS.verbose:
          tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
          print("%10d %10d (%6.2f%%) %7s %s %s %s" % (
                xf.patch_len, tgt_size, xf.patch_len * 100.0 / tgt_size,
                xf.style,
                xf.tgt_name if xf.tgt_name == xf.src_name else (
                    xf.tgt_name + " (from " + xf.src_name + ")"),
                xf.tgt_ranges, xf.src_ranges))

  def AssertSequenceGood(self):
    # Simulate the sequences of transfers we will output, and check that:
@@ -1211,7 +1244,9 @@ class BlockImageDiff(object):
      # Change nothing for small files.
      if (tgt_ranges.size() <= max_blocks_per_transfer and
          src_ranges.size() <= max_blocks_per_transfer):
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
                 style, by_id)
        return

      while (tgt_ranges.size() > max_blocks_per_transfer and
@@ -1221,8 +1256,9 @@ class BlockImageDiff(object):
        tgt_first = tgt_ranges.first(max_blocks_per_transfer)
        src_first = src_ranges.first(max_blocks_per_transfer)

        Transfer(tgt_split_name, src_split_name, tgt_first, src_first, style,
                 by_id)
        Transfer(tgt_split_name, src_split_name, tgt_first, src_first,
                 self.tgt.RangeSha1(tgt_first), self.src.RangeSha1(src_first),
                 style, by_id)

        tgt_ranges = tgt_ranges.subtract(tgt_first)
        src_ranges = src_ranges.subtract(src_first)
@@ -1234,8 +1270,9 @@ class BlockImageDiff(object):
        assert tgt_ranges.size() and src_ranges.size()
        tgt_split_name = "%s-%d" % (tgt_name, pieces)
        src_split_name = "%s-%d" % (src_name, pieces)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges, style,
                 by_id)
        Transfer(tgt_split_name, src_split_name, tgt_ranges, src_ranges,
                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
                 style, by_id)

    def AddTransfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id,
                    split=False):
@@ -1244,7 +1281,9 @@ class BlockImageDiff(object):
      # We specialize diff transfers only (which covers bsdiff/imgdiff/move);
      # otherwise add the Transfer() as is.
      if style != "diff" or not split:
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges, style, by_id)
        Transfer(tgt_name, src_name, tgt_ranges, src_ranges,
                 self.tgt.RangeSha1(tgt_ranges), self.src.RangeSha1(src_ranges),
                 style, by_id)
        return

      # Handle .odex files specially to analyze the block-wise difference. If
+11 −4
Original line number Diff line number Diff line
@@ -144,6 +144,12 @@ class SparseImage(object):
    f.seek(16, os.SEEK_SET)
    f.write(struct.pack("<2I", self.total_blocks, self.total_chunks))

  def RangeSha1(self, ranges):
    """Return the hex SHA-1 of the data in 'ranges', hashed as it streams.

    Mirrors DataImage.RangeSha1 in blockimgdiff.py: consumes the
    _GetRangeData() generator without collecting the chunks into a list.
    """
    h = sha1()
    for data in self._GetRangeData(ranges):
      h.update(data)
    return h.hexdigest()

  def ReadRangeSet(self, ranges):
    # Materializes every chunk into a list — contrast with the streaming
    # RangeSha1()/WriteRangeDataToFd() added by this commit.
    return [d for d in self._GetRangeData(ranges)]

@@ -155,10 +161,11 @@ class SparseImage(object):
    ranges = self.care_map
    if not include_clobbered_blocks:
      ranges = ranges.subtract(self.clobbered_blocks)
    h = sha1()
    for d in self._GetRangeData(ranges):
      h.update(d)
    return h.hexdigest()
    return self.RangeSha1(ranges)

  def WriteRangeDataToFd(self, ranges, fd):
    """Stream the data in 'ranges' to file object 'fd' chunk by chunk."""
    for data in self._GetRangeData(ranges):
      fd.write(data)

  def _GetRangeData(self, ranges):
    """Generator that produces all the image data in 'ranges'.  The