#!/usr/bin/python3

# Copyright 2020-2021 Yury Gribov
# 
# Use of this source code is governed by MIT license that can be
# found in the LICENSE.txt file.

"""
Wrapper for ld which automatically calls implib-gen.
To enable, add it to PATH.
Flags are passed via env variable IMPLIBSO_LD_OPTIONS.
"""

import sys
import os
import os.path
import re
import subprocess
import argparse
import tempfile
import atexit
import shutil

# Name of this script; used as prefix in all diagnostic messages.
me = os.path.basename(__file__)
# Verbosity level; main() overwrites it with the number of --verbose flags.
v = 0

def warn(msg):
  """
  Report a non-fatal problem on stderr (prefixed with script name).
  """
  text = f'{me}: warning: {msg}\n'
  sys.stderr.write(text)

def error(msg):
  """
  Report a fatal problem on stderr (prefixed with script name)
  and terminate the script with exit code 1.
  """
  text = f'{me}: error: {msg}\n'
  sys.stderr.write(text)
  sys.exit(1)

def note(msg):
  """
  Print an informational message on stderr (prefixed with script name).
  """
  text = f'{me}: {msg}\n'
  sys.stderr.write(text)

def run(cmd):
  """
  Simple wrapper for subprocess.Popen.

  `cmd` is either a list of arguments or a string which is split
  on whitespace (shell quoting is NOT supported so pass a list
  if arguments may contain spaces).

  If command fails, its output is forwarded to our stdout/stderr
  and the script exits with command's exit code. Otherwise returns
  a (returncode, stdout, stderr) tuple with decoded output.
  """
  if isinstance(cmd, str):
    # split() without arguments does not produce empty arguments
    # for repeated, leading or trailing whitespace
    cmd = cmd.split()
  if v > 0:
    note(f"running {cmd}...")
  with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
    out, err = p.communicate()
    # Tool output is not guaranteed to be valid UTF-8
    # so replace bad bytes instead of crashing
    out = out.decode('utf-8', errors='replace')
    err = err.decode('utf-8', errors='replace')
    if p.returncode != 0:
      sys.stdout.write(out)
      sys.stderr.write(err)
      sys.exit(p.returncode)
    return p.returncode, out, err

def which(exe, alll=False):
  """
  Analog of which(1).

  Looks for executable files named `exe` in directories listed in PATH
  (os.defpath is used if PATH is not set).

  Returns list of all matches (in PATH order) if `alll` is true,
  otherwise the first match or None if there are none.
  """
  search_dirs = os.environ.get('PATH', os.defpath).split(os.pathsep)
  exes = []
  for d in search_dirs:
    f = os.path.join(d, exe)
    # Like which(1), ignore files which can not be executed
    if os.path.isfile(f) and os.access(f, os.X_OK):
      exes.append(f)
  if alll:
    return exes
  return exes[0] if exes else None

def is_system_library(name):
  """
  Tell whether library is one of the basic system libraries
  (such libraries are never wrapped).
  """
  basic_libs = ['c', 'm', 'dl', 'rt', 'pthread', 'gcc', 'gcc_s']
  if name in basic_libs:
    return True
  return False

class _WrapperInfo:
  "Holds info about replaceable library"
  def __init__(self, name, path):
    self.name = name
    self.path = path
    self.wrappers = []  # Object files with compiled wrappers
    self.replaced = False  # Was some -l flag substituted with wrappers?

def _parse_lib_list(param):
  """
  Convert comma-separated list of libraries to set of bare names
  (e.g. both 'libfoo.so.1' and 'foo' become 'foo').
  """
  names = set()
  for lib in param.split(','):
    if lib:
      # Keep the name itself (group 2), dropping 'lib' prefix and '.so*' suffix
      names.add(re.sub(r'^(lib)?(.*)(\.so).*', r'\2', lib))
  return names

def _find_output_name(argv):
  """
  Return name of linker output file specified via '-o FILE' or '-oFILE'
  in command line `argv` ('a.out' if there is none).
  """
  out_filename = 'a.out'
  i = 1
  while i < len(argv):
    arg = argv[i]
    i += 1
    if arg == '-o':
      # Missing argument will be diagnosed by real ld
      if i < len(argv):
        out_filename = argv[i]
        i += 1  # Do not interpret file name as an option
    elif arg.startswith('-o'):
      out_filename = arg[2:]
  return out_filename

def _replace_libs(argv, libs):
  """
  Return copy of linker command line `argv` in which '-l NAME'/'-lNAME'
  flags of libraries in `libs` (dict of name -> _WrapperInfo) are replaced
  with their wrapper object files. Marks substituted libraries as replaced.
  """
  args = [argv[0]]
  i = 1
  while i < len(argv):
    arg = argv[i]
    i += 1
    if arg == '-l' and i < len(argv):
      name = argv[i]
      flag = [arg, name]
      i += 1  # Library name is consumed together with the flag
    elif arg.startswith('-l') and len(arg) > 2:
      name = arg[2:]
      flag = [arg]
    else:
      args.append(arg)
      continue
    info = None if is_system_library(name) else libs.get(name)
    if info is None or not info.wrappers:
      # Nothing to substitute so keep the original flag
      args.extend(flag)
      continue
    # Library may be mentioned several times but its wrappers
    # must be linked only once to avoid duplicate definitions
    if not info.replaced:
      args.extend(info.wrappers)
      info.replaced = True
  return args

def main():
  """
  Entry point.

  Runs real ld with the original arguments, determines linked shared
  libraries via ldd, generates and compiles implib-gen wrappers for
  them and finally reruns real ld with libraries replaced by wrappers.
  Wrapper's own flags are read from IMPLIBSO_LD_OPTIONS environment
  variable (sys.argv is passed to real ld). Always ends with sys.exit.
  """

  parser = argparse.ArgumentParser(description="""\
Ld wrapper to simplify application of implib-gen.

Flags can be specified directly or via IMPLIBSO_LD_OPTIONS environment variable.
""")
  parser.add_argument('--verbose', '-v',
                      help="Print diagnostic info.",
                      action='count',
                      default=0)
  parser.add_argument('--outdir', '-o',
                      help="Directory for temp files.",
                      default='')
  parser.add_argument('--wrap-libs',
                      help="Wrap only libraries in comma-separated list.",
                      default='')
  parser.add_argument('--no-wrap-libs',
                      help="Do not wrap only libraries in comma-separated list.",
                      default='')

  # Print help if called without any arguments and options
  env_default = '--help' if len(sys.argv) <= 1 else ''
  env_opts = os.environ.get('IMPLIBSO_LD_OPTIONS', env_default)
  # split() without arguments tolerates repeated whitespace
  env_args = parser.parse_args(env_opts.split())

  global v  # pylint: disable=global-statement
  v = env_args.verbose

  if env_args.outdir:
    tmpd = env_args.outdir
  else:
    # Private temp directory is removed when script exits
    tmpd = tempfile.mkdtemp('implib-gen-ld')
    atexit.register(lambda: shutil.rmtree(tmpd))

  wrap_libs = _parse_lib_list(env_args.wrap_libs)
  no_wrap_libs = _parse_lib_list(env_args.no_wrap_libs)

  out_filename = _find_output_name(sys.argv)

  # Real ld is the next executable with same name in PATH.
  # argv[0] may be a full path to wrapper so search by basename.
  all_lds = which(os.path.basename(sys.argv[0]), alll=True)
  if not all_lds:
    error("no ld executables in PATH")
  if len(all_lds) == 1:
    error("real ld executable not found in PATH")
  real_ld = all_lds[1]
  this_ld = os.path.realpath(__file__)
  if os.path.realpath(real_ld) == os.path.realpath(this_ld):
    error("ld wrapper is not the first executable in PATH")
  sys.argv[0] = real_ld

  # Generate normal output first
  rc, out, err = run(sys.argv)
  if rc != 0:
    sys.stdout.write(out)
    sys.stderr.write(err)
    sys.exit(rc)

  # Analyze ldd output to see which runtime libs were linked
  # (commands are passed as lists so that paths with spaces work)
  rc, out, err = run(['ldd', out_filename])
  os.unlink(out_filename)  # Remove output in case we fail later
  if rc != 0:
    error(f"ldd failed: {err}")
  sys.stderr.write(err)

  libs = {}
  for l in out.split('\n'):
    l = l.strip()
    if not l:
      continue
    if re.search(r'linux-vdso|ld-linux', l):
      if v > 0:
        note(f"skipping system library: {l}")
      continue
    m = re.search(r'^lib(\S*)\.so(\.[0-9]+)? => (\S*)', l)
    if m is None:
      warn(f"failed to parse ldd output: {l}")
      continue
    name = m.group(1)
    path = m.group(3)
    if is_system_library(name):
      if v > 0:
        note(f"skipping system library: {l}")
      continue
    if wrap_libs and name not in wrap_libs:
      if v > 0:
        note(f"skipping library not in used-defined whitelist: {l}")
      continue
    if no_wrap_libs and name in no_wrap_libs:
      if v > 0:
        note(f"skipping library in used-defined blacklist: {l}")
      continue
    if v > 0:
      note(f"wrappable library: {name} ({path})")
    libs[name] = _WrapperInfo(name, path)

  # Compile wrappers
  for _, info in sorted(libs.items()):
    rc, gen_out, err = run(['implib-gen.py', '--outdir', tmpd, info.path])
    if rc != 0:
      error(f"implib-gen failed: {err}")
    sys.stderr.write(err)
    # implib-gen reports generated files as 'Generating FILE...'
    for l in gen_out.split('\n'):
      m = re.match(r'^Generating (.*)\.\.\.', l.strip())
      if m is None:
        continue
      f = m.group(1)
      o = os.path.splitext(f)[0] + '.o'
      if v > 0:
        note(f"compiling wrapper for {f} in {o}")
      rc, _, err = run(['gcc', '-Wall', '-Wextra', '-O2', '-c', '-o', o, f])
      if rc != 0:
        error(f"gcc failed: {err}")
      sys.stderr.write(err)
      info.wrappers.append(o)

  # Relink with wrapped code
  args = _replace_libs(sys.argv, libs)
  if any(info.replaced for info in libs.values()):
    # Wrappers load libraries at runtime so need libdl
    args.append('-ldl')

  for name, info in sorted(libs.items()):
    if not info.replaced:
      warn(f"failed to replace library {name}")

  rc, out, err = run(args)
  sys.stdout.write(out)
  sys.stderr.write(err)
  sys.exit(rc)

# Run the wrapper only when executed as a script (not when imported).
if __name__ == '__main__':
  main()
