#!/usr/bin/env python3.6

'''
Copyright (c) 2019 Martin Storsjo

Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.

The software is provided "as is" and the author(s) disclaim(s) all warranties
with regard to this software including all implied warranties of
merchantability and fitness. In no event shall the author(s) be liable for
any special, direct, indirect, or consequential damages or any damages
whatsoever resulting from loss of use, data or profits, whether in an
action of contract, negligence or other tortious action, arising out of
or in connection with the use or performance of this software.

Modified by chatterjee arnab

chmod +x VC.py
./VC.py --accept-license --ignores ASAN CodeSense CoreEditor CoreIDE Debugger DIA.SDK Editors IntelliTrace LiveShareApi OneCore SccCodeLenses UnitTest TestTools WebSiteProject WebTools Windows10SDK
'''

from argparse import ArgumentParser
from functools import cmp_to_key
from hashlib import sha256
from json import loads
from os import access, F_OK, listdir, makedirs,  remove
from os.path import abspath, isdir, isfile, join
from shutil import move, rmtree
from sys import exit
from urllib.parse import unquote
from urllib.request import urlopen, urlretrieve
from zipfile import ZipFile

def getArgsParser():
	"""Build the command line parser for the downloader.

	Returns:
		An ArgumentParser describing every supported option. Parsing yields a
		namespace with: accept_license ("y"/"n"), chip, dest, ignores (list or
		None), include_optional, list_components, list_packages,
		list_workloads, major, manifest, only_download, packages, type
		("release"/"pre"), print_deps_tree, print_reverse_deps,
		print_selection, print_version, save_manifest and skip_recommended.
	"""
	parser = ArgumentParser(description="Download and install Visual Studio.")
	parser.add_argument("--accept-license", action="store_const", const="y", default="n", help="Don't prompt for accepting the license.")
	parser.add_argument("--chip", metavar="arch", default="X86", help="ARM/X64/X86 (defaults to X86).")
	# The default used to be "...", which names a directory literally called
	# "..." instead of the current directory promised by the help text.
	parser.add_argument("--dest", metavar="dir", default=".", help="Directory to install into (defaults to current directory).")
	parser.add_argument("--ignores", metavar="component", nargs="+", help="Packages to skip.")
	parser.add_argument("--include-optional", action="store_true", help="Include all optional dependencies.")
	parser.add_argument("--list-components", action="store_true", help="List available components.")
	parser.add_argument("--list-packages", action="store_true", help="List all individual packages, regardless of type.")
	parser.add_argument("--list-workloads", action="store_true", help="List high level workloads.")
	parser.add_argument("--major", metavar="version", default=16, help="The major version to download (defaults to 16).")
	parser.add_argument("--manifest", metavar="manifest", help="A predownloaded manifest file.")
	parser.add_argument("--only-download", action="store_true", help="Stop after downloading package files.")
	parser.add_argument("--packages", metavar="package", default=["Microsoft.VisualStudio.Workload.VCTools"], nargs="+", help="Packages to install. If omitted, installs the default command line tools.")
	# --preview is stored under "type" because the value is spliced into the channel URL.
	parser.add_argument("--preview", action="store_const", const="pre", default="release", dest="type", help="Download the preview version instead of the release version.")
	parser.add_argument("--print-deps-tree", action="store_true", help="Print a tree of resolved dependencies for the given selection.")
	parser.add_argument("--print-reverse-deps", action="store_true", help="Print a tree of packages that depend on the given selection.")
	parser.add_argument("--print-selection", action="store_true", help="Print a list of the individual packages that are selected to be installed.")
	parser.add_argument("--print-version", action="store_true", help="Stop after fetching the manifest.")
	parser.add_argument("--save-manifest", action="store_true", help="Store the downloaded manifest to a file.")
	parser.add_argument("--skip-recommended", action="store_true", help="Skip all recommended dependencies.")
	return parser

def lowercaseIgnores(args):
	"""Replace args.ignores in place with a set of its entries in lower case.

	When no ignores were given (args.ignores is None) the attribute becomes an
	empty set. Nothing is returned.
	"""
	requested = args.ignores
	if requested is None:
		requested = ()
	args.ignores = {entry.lower() for entry in requested}

def getManifest(args):
	"""Load the installer manifest, optionally saving a copy of it.

	When args.manifest is unset, the channel manifest for args.major/args.type
	is fetched first and the URL of its (last) "Manifest" channel item is
	stored in args.manifest. A value of args.manifest that does not start with
	"http" is treated as a local file and opened through a "file:" URL.

	With args.save_manifest set, the raw manifest bytes are written to
	"<productDisplayVersion>.manifest" in the current directory unless that
	file already exists, in which case it is only compared and never
	overwritten.

	Returns:
		The parsed installer manifest (a dict).

	Raises:
		SystemExit: with status 1 when the channel lists no installer manifest.
	"""
	if args.manifest is None:
		url = "https://aka.ms/vs/%s/%s/channel" % (args.major, args.type)
		print("Fetching %s" % (url))
		with urlopen(url) as response:
			manifest = loads(response.read())
		print("Got toplevel manifest for %s" % (manifest["info"]["productDisplayVersion"]))
		for channelItem in manifest["channelItems"]:
			if channelItem.get("type") == "Manifest": args.manifest = channelItem["payloads"][0]["url"]
		if args.manifest is None:
			print("Unable to find an installer manifest.")
			exit(1)
	if not args.manifest.startswith("http"): args.manifest = "file:" + args.manifest
	with urlopen(args.manifest) as response:
		manifestdata = response.read()
	manifest = loads(manifestdata)
	print("Loaded installer manifest for %s" % (manifest["info"]["productDisplayVersion"]))
	if args.save_manifest:
		filename = "%s.manifest" % (manifest["info"]["productDisplayVersion"])
		# manifestdata is bytes, so the saved copy must be read and written in
		# binary mode; text mode made the comparison always fail and the write
		# raise TypeError.
		if isfile(filename):
			with open(filename, "rb") as f:
				oldfile = f.read()
			if oldfile != manifestdata: print("Old saved manifest in \"%s\" differs from newly downloaded one, not overwriting." % (filename))
			else: print("Old saved manifest in \"%s\" is still current" % (filename))
		else:
			with open(filename, "wb") as f:
				f.write(manifestdata)
			print("Saved installer manifest to \"%s\"" % (filename))
	return manifest

def prioritizePackage(a, b):
	"""Comparison function ordering package variants by preference.

	x64 variants sort before other chips and, when that does not decide,
	"en-*" languages sort before other languages. A criterion is only applied
	when both packages carry the field.

	Returns:
		-1 when a is preferred, 1 when b is preferred, 0 when equivalent.
	"""
	def compare(field, preferred):
		left, right = a.get(field), b.get(field)
		if left is None or right is None:
			return 0
		# bool - bool gives -1 when only the left side is preferred, 1 when only the right is.
		return preferred(right.lower()) - preferred(left.lower())

	return compare("chip", lambda value: value == "x64") or compare("language", lambda value: value.startswith("en-"))

def getPackages(manifest):
	"""Index the manifest's packages by lower-cased id.

	Returns:
		A dict mapping each lower-cased package id to the list of packages
		sharing that id, sorted with prioritizePackage so the preferred
		variant comes first.
	"""
	grouped = {}
	for entry in manifest["packages"]:
		grouped.setdefault(entry["id"].lower(), []).append(entry)
	return {key: sorted(variants, key=cmp_to_key(prioritizePackage)) for key, variants in grouped.items()}

def listPackageType(packages, ptype):
	"""Print, in sorted order, the ids of the first variant of each package.

	Args:
		packages: dict of package lists, as built by getPackages.
		ptype: package type to keep (compared case-insensitively), or None to
			list every package.
	"""
	wanted = None if ptype is None else ptype.lower()

	def selected(package):
		if wanted is None:
			return True
		kind = package.get("type")
		return kind is not None and kind.lower() == wanted

	preferred = (variants[0] for variants in packages.values())
	for name in sorted(package["id"] for package in preferred if selected(package)):
		print(name)

def findPackage(packages, id, chip, warn=True):
	"""Look up a package by id (case-insensitively), preferring a given chip.

	Args:
		packages: dict of package lists keyed by lower-cased id.
		id: package id to look up.
		chip: chip to prefer (case-insensitive), or None for no preference.
		warn: print a warning when the id is unknown.

	Returns:
		The first variant whose chip matches, otherwise the first variant of
		the package, or None when the id is not present.
	"""
	try:
		candidates = packages[id.lower()]
	except KeyError:
		if warn:
			print("Warning: %s not found" % (id))
		return None
	if chip is None:
		return candidates[0]
	wanted = chip.lower()
	matching = (candidate for candidate in candidates if candidate.get("chip") is not None and candidate.get("chip").lower() == wanted)
	return next(matching, candidates[0])

def printDepends(packages, target, dtype, chip, indent, args):
	"""Print target and, recursively, the tree of its dependencies.

	Each line is the indent followed by the package name and, where they
	apply, " (<chip>)", " (<dtype>)" and " (Ignored)". Recursion stops at
	optional dependencies unless args.include_optional is set, at recommended
	ones when args.skip_recommended is set, at targets containing one of
	args.ignores, and at packages that are unknown or have no dependencies.
	"""
	chipstr = "" if chip is None else " (" + chip + ")"
	dtypestr = "" if dtype == "" else " (" + dtype + ")"
	lowered = target.lower()
	ignore = any(ign in lowered for ign in args.ignores)
	ignorestr = " (Ignored)" if ignore else ""
	print(indent + target + chipstr + dtypestr + ignorestr)

	if dtype == "Optional" and not args.include_optional:
		return
	if dtype == "Recommended" and args.skip_recommended:
		return
	if ignore:
		return

	package = findPackage(packages, target, chip)
	if package is None:
		return
	dependencies = package.get("dependencies")
	if dependencies is None:
		return
	for name, dependency in dependencies.items():
		# A dependency is either a detailed dict or a bare value carrying no chip/type.
		childChip = None
		childType = None
		if isinstance(dependency, dict):
			childChip = dependency.get("chip")
			childType = dependency.get("type")
		if childType is None:
			childType = ""
		printDepends(packages, name, childType, childChip, indent + " ", args)

def printReverseDepends(packages, target, dtype, indent, args):
	"""Print target and, recursively, the tree of packages depending on it.

	Args:
		packages: dict of package lists keyed by lower-cased id.
		target: id of the package whose dependents are wanted.
		dtype: how target is depended upon by the line above it ("" at the
			root); shown in parentheses after the name.
		indent: prefix for the printed line; one space is added per level.
		args: parsed options; include_optional and skip_recommended decide
			whether optional/recommended edges are followed further.
	"""
	dtypestr = " (" + dtype + ")" if dtype != "" else ""
	print(indent + target + dtypestr)
	if dtype == "Optional" and not args.include_optional: return
	if dtype == "Recommended" and args.skip_recommended: return
	wanted = target.lower()
	# Iterate the package lists, not the dict keys: iterating the dict itself
	# yields id strings, so no dependent was ever found.
	for candidates in packages.values():
		package = candidates[0]
		dependencies = package.get("dependencies")
		if dependencies is None: continue
		for key, dependency in dependencies.items():
			if key.lower() != wanted: continue
			deptype = dependency.get("type") if isinstance(dependency, dict) else None
			if deptype is None: deptype = ""
			printReverseDepends(packages, package["id"], deptype, indent + " ", args)

def getPackageKey(package, targetChip=None):
	"""Build the "<id>[-<version>][-<chip>]" key of a package for the target chip.

	Args:
		package: package dict from the manifest.
		targetChip: architecture being installed ("ARM", "X64" or "X86",
			any case). When omitted, the chip of the module-level parsed
			arguments (args.chip) is used, which only exist when the file is
			run as a script.

	Returns:
		The key string, or None when the package does not apply to the target
		chip: either its "chip" field names another chip, or its id mentions
		one of the other architectures without mentioning the target one.
	"""
	if targetChip is None: targetChip = args.chip
	targetChip = targetChip.upper()
	chip = package.get("chip")
	if chip is not None and chip.upper() != targetChip: return None
	packagekey = package["id"]
	pk = packagekey.upper()
	others = [arch for arch in ("ARM", "X64", "X86") if arch != targetChip]
	if targetChip not in pk and any(arch in pk for arch in others): return None
	version = package.get("version")
	if version is not None: packagekey += "-" + version
	if chip is not None: packagekey += "-" + chip
	return packagekey

def aggregateDepends(packages, included, target, chip, args):
	"""Collect target and its transitive dependencies that should be installed.

	Args:
		packages: dict of package lists keyed by lower-cased id.
		included: dict of package keys already collected; updated in place so
			every package is returned only once.
		target: id of the package to resolve.
		chip: chip requested for the target, or None.
		args: parsed options; ignores, include_optional and skip_recommended
			are consulted.

	Returns:
		A list of package dicts, the target first. The list is empty when the
		target is ignored, unknown, not applicable to the selected chip or
		already included.
	"""
	lowered = target.lower()
	if any(ignore in lowered for ignore in args.ignores): return []
	package = findPackage(packages, target, chip)
	# Always return a list: callers extend() with the result, so the former
	# bare "return" crashed the whole resolution on one unknown dependency.
	if package is None: return []
	packagekey = getPackageKey(package)
	if packagekey is None or packagekey in included: return []
	ret = [package]
	included[packagekey] = True
	dependencies = package.get("dependencies")
	if dependencies is None: return ret
	for key, dependency in dependencies.items():
		depchip = None
		if isinstance(dependency, dict):
			dtype = dependency.get("type")
			if dtype == "Optional" and not args.include_optional: continue
			if dtype == "Recommended" and args.skip_recommended: continue
			depchip = dependency.get("chip")
		ret.extend(aggregateDepends(packages, included, key, depchip, args))
	return ret

def getSelectedPackages(packages, args):
	"""Resolve every package named in args.packages together with its dependencies.

	Returns:
		The concatenated results of aggregateDepends for each requested
		package; one shared "included" dict keeps packages from repeating.
	"""
	seen = {}
	selection = []
	for requested in args.packages:
		selection += aggregateDepends(packages, seen, requested, None, args)
	return selection

def sumDownloadSize(packages):
	"""Return the sum of the "size" of every payload of the given packages.

	Packages without payloads and payloads without a size contribute nothing.
	"""
	sizes = (
		payload.get("size")
		for package in packages
		if package.get("payloads") is not None
		for payload in package.get("payloads")
	)
	return sum(size for size in sizes if size is not None)

def sumInstallSize(packages):
	"""Return the sum of all values in the "installSizes" dict of every package.

	Packages without "installSizes" contribute nothing.
	"""
	return sum(
		amount
		for package in packages
		if package.get("installSizes") is not None
		for amount in package.get("installSizes").values()
	)

def formatSize(size):
	"""Format a byte count for display.

	Sizes above 900 MB are shown in GB, above 900 KB in MB and above 1024
	bytes in KB (each with one decimal); anything else is shown in bytes.
	"""
	kilo = 1024
	mega = kilo * kilo
	giga = mega * kilo
	# (exclusive lower bound, divisor, unit), largest unit first.
	scales = ((900 * mega, giga, "GB"), (900 * kilo, mega, "MB"), (kilo, kilo, "KB"))
	for threshold, divisor, unit in scales:
		if size > threshold:
			return "%.1f %s" % (size / divisor, unit)
	return "%d bytes" % (size)

def printPackages(packages):
	"""Print one line per package, sorted by id.

	Each line holds the id, then the type, chip and language in parentheses
	when present, then the formatted download size of that package.
	"""
	for package in sorted(packages, key=lambda package: package["id"]):
		parts = [package["id"]]
		for field in ("type", "chip", "language"):
			value = package.get(field)
			if value is not None:
				parts.append("(" + value + ")")
		parts.append(formatSize(sumDownloadSize([package])))
		print(" ".join(parts))

def getPayloadName(payload):
	"""Return the payload's "fileName" without any leading directories.

	Both backslashes and forward slashes count as directory separators.
	"""
	name = payload["fileName"]
	# Everything after the last separator of either kind; rfind gives -1 when
	# a separator is absent, so a plain name is returned unchanged.
	lastSeparator = max(name.rfind("\\"), name.rfind("/"))
	return name[lastSeparator + 1:]

def sha256File(file):
	"""Return the hexadecimal SHA-256 digest of the file at the given path.

	The file is read in 4096-byte blocks.
	"""
	digest = sha256()
	with open(file, "rb") as stream:
		while True:
			block = stream.read(4096)
			if block == b"":
				break
			digest.update(block)
	return digest.hexdigest()

def downloadPackages(packages, dest, allowHashMismatch = False):
	"""Download the payloads of the given packages below dest.

	Every package gets its own directory named after its package key. A file
	that already exists is kept when the payload lists no sha256 or when its
	hash matches; otherwise it is removed and downloaded again. After a
	download, a hash mismatch aborts the program with status 1 unless
	allowHashMismatch is set, in which case only a warning is printed.
	Finally the total of the downloaded payload sizes is printed.
	"""
	def hashMatches(path, payload):
		return sha256File(path).lower() == payload["sha256"].lower()

	totalBytes = 0
	for package in packages:
		if "payloads" not in package:
			continue
		packageKey = getPackageKey(package)
		targetDir = join(dest, packageKey)
		makedirs(targetDir, exist_ok=True)
		for payload in package["payloads"]:
			payloadName = getPayloadName(payload)
			destname = join(targetDir, payloadName)
			fileid = join(packageKey, payloadName)
			if access(destname, F_OK):
				# Without a reference hash an existing file is taken on trust.
				if "sha256" not in payload:
					continue
				if hashMatches(destname, payload):
					print("Using existing file %s" % (fileid))
					continue
				print("Incorrect existing file %s, removing" % (fileid))
				remove(destname)
			size = payload.get("size")
			if size is None:
				size = 0
			print("Downloading %s (%s)" % (fileid, formatSize(size)))
			urlretrieve(payload["url"], destname)
			totalBytes += size
			if "sha256" not in payload or hashMatches(destname, payload):
				continue
			if not allowHashMismatch:
				print("Incorrect hash for downloaded file %s, aborting" % (fileid))
				exit(1)
			print("Warning: Incorrect hash for downloaded file %s" % (fileid))
	print("Downloaded %s in total" % (formatSize(totalBytes)))

def extractFiltered(zip, dest):
	"""Extract every member of an open zip archive into dest with URL-decoded names.

	Members are first extracted into the temporary directory "<dest>/extract"
	and then moved to "<dest>/<unquoted member name>"; the temporary directory
	is removed at the end.
	"""
	staging = join(dest, "extract")
	for member in zip.infolist():
		decoded = unquote(member.filename)
		parent, slash, _ = decoded.rpartition("/")
		if slash:
			# The decoded parent directory may differ from the one extract() creates.
			makedirs(join(dest, parent), exist_ok=True)
		move(zip.extract(member, staging), join(dest, decoded))
	rmtree(staging)

def mergeTrees(src, dest):
	"""Move the contents of directory src into dest, merging directories.

	Nothing happens when src is not a directory; when dest is not a directory,
	src is simply moved there. Otherwise each entry of src is moved into dest,
	except that a sub-directory which already exists in dest (also under a
	name differing only in case) is merged recursively into the existing one.
	"""
	if not isdir(src):
		return
	if not isdir(dest):
		move(src, dest)
		return
	entries = listdir(src)
	# Lower-cased name -> actual name in dest, for case-insensitive matching.
	existing = {entry.lower(): entry for entry in listdir(dest)}
	for entry in entries:
		source = join(src, entry)
		target = join(dest, entry)
		if not isdir(source):
			move(source, target)
		elif isdir(target):
			mergeTrees(source, target)
		elif entry.lower() in existing:
			mergeTrees(source, join(dest, existing[entry.lower()]))
		else:
			move(source, target)

def unpackVsix(file, dest, listing):
	"""Unpack one VSIX archive into dest and record the names of its members.

	The archive is extracted into the temporary directory "<dest>/vsix", its
	member names are written one per line to the listing file, the extracted
	"Contents" directory (when present) is merged into dest and the temporary
	directory is removed.
	"""
	workdir = join(dest, "vsix")
	makedirs(workdir, exist_ok=True)
	with ZipFile(file, "r") as archive:
		extractFiltered(archive, workdir)
		with open(listing, "w") as out:
			out.writelines(member + "\n" for member in archive.namelist())
	payloadRoot = join(workdir, "Contents")
	if access(payloadRoot, F_OK):
		mergeTrees(payloadRoot, dest)
	rmtree(workdir)

def unpackPackages(packages, dest):
	"""Unpack the downloaded Vsix packages and merge their VC tree into dest.

	Packages of type Component, Workload or Group are skipped silently, other
	non-Vsix types are skipped with a message. Every payload of a Vsix package
	is unpacked into the temporary directory "<dest>/unpack"; afterwards its
	"VC" sub-directory is merged into "<dest>/VC" and the temporary directory
	is removed.
	"""
	staging = join(dest, "unpack")
	makedirs(staging, exist_ok=True)
	containerTypes = ("Component", "Workload", "Group")
	for package in packages:
		ptype = package["type"]
		packageDir = join(dest, getPackageKey(package))
		if ptype in containerTypes:
			continue
		if ptype != "Vsix":
			print("Skipping unpacking of " + package["id"] + " of type " + ptype)
			continue
		print("Unpacking " + package["id"])
		for payload in package["payloads"]:
			archive = join(packageDir, getPayloadName(payload))
			listing = join(staging, getPackageKey(package) + "-listing.txt")
			unpackVsix(archive, staging, listing)
	mergeTrees(join(staging, "VC"), join(dest, "VC"))
	rmtree(staging)

# Script entry point: parse the options, load the manifest, serve the
# informational modes, then resolve, download and unpack the selection.
if __name__ == "__main__":
	parser = getArgsParser()
	# NOTE: args stays a module-level name on purpose; getPackageKey reads args.chip from it.
	args = parser.parse_args()
	if args.chip not in ("ARM", "X64", "X86"):
		print("Bad chip architecture.")
		exit(1)
	# De-duplicate the requested packages (their order is not preserved).
	args.packages = set(args.packages)
	lowercaseIgnores(args)
	packages = getPackages(getManifest(args))
	# getManifest has already printed the version, so there is nothing left to do.
	if args.print_version: exit(0)
	# Informational modes: each prints its listing and stops before the license prompt.
	if args.list_components:
		listPackageType(packages, "Component")
		exit(0)
	if args.list_packages:
		listPackageType(packages, None)
		exit(0)
	if args.list_workloads:
		listPackageType(packages, "Workload")
		exit(0)
	if args.print_deps_tree:
		for package in args.packages: printDepends(packages, package, "", None, "", args)
		exit(0)
	if args.print_reverse_deps:
		for package in args.packages: printReverseDepends(packages, package, "", "", args)
		exit(0)
	# Keep asking until the license is accepted ("y"/"yes"); "n"/"no" quits.
	# With --accept-license the value is already "y" and the loop is skipped.
	while args.accept_license not in ("y", "yes"):
		args.accept_license = input("Do you accept the license at " + findPackage(packages, "Microsoft.VisualStudio.Product.BuildTools", None)["localizedResources"][0]["license"] + " (Y/N)? ").lower()
		if args.accept_license in ("n", "no"): exit(0)
	# From here on, packages is the flat list of selected package dicts, no longer the id index.
	packages = getSelectedPackages(packages, args)
	print("Selected %d packages, for a total download size of %s, install size of %s" % (len(packages), formatSize(sumDownloadSize(packages)), formatSize(sumInstallSize(packages))))
	if args.print_selection:
		printPackages(packages)
		exit(0)
	dest = abspath(args.dest)
	makedirs(dest, exist_ok=True)
	# A hash mismatch is only tolerated (as a warning) when nothing will be unpacked.
	downloadPackages(packages, dest, allowHashMismatch=args.only_download)
	if not args.only_download: unpackPackages(packages, dest)

