use core:lang;
use core:asm;
use lang:asm;
use lang:bs:macro;

/**
 * A pointer- or reference type in the C++ implementation.
 *
 * Represented as a pointer to the start of the object followed by an integer offset. This lets us
 * check if a pointer dereference would be in range, and makes the GC happy.
 *
 * Can also act as a reference.
 *
 * Note: All helpers (and the actual types) are in the util package since Progvis needs to reload
 * the package that contains the actual definitions in order to get their machine code. This way we
 * avoid reloading the entire C frontend every time (parsing is expensive).
 */
class PtrType extends Type {
	// Create a pointer or reference type.
	// - inside: the pointed-to type, stored as params[0].
	// - name: the name of this type.
	// - isConst: whether the pointed-to value is 'const'.
	// - isRef, rvalRef: which kind of reference this is. Neither set means a plain pointer (see 'isPtr').
	init(Value inside, Str name, Bool isConst, Bool isRef, Bool rvalRef) {
		init(name, [inside], TypeFlags:typeValue) { isRef = isRef; rvalRef = rvalRef; isConst = isConst; }
	}

	// Is this a reference?
	Bool isRef;

	// R-value reference?
	Bool rvalRef;

	// Is the thing we're pointing to 'const'?
	Bool isConst;

	// Is this a pointer? True only when neither reference flag is set.
	Bool isPtr() {
		!isRef & !rvalRef;
	}

	// Get the type inside (the type of params[0]), if any.
	Type? inside() {
		params[0].type;
	}

	// Load members: the two data fields, copy/assign, comparisons, and, when the inner type is
	// known, pointer arithmetic plus the 'allocArray' and 'deepCopy' helpers.
	Bool loadAll() : override {
		// Note: We make assumptions regarding the type in generated code. Don't alter the order of
		// these! 'base' lives at offset 0 and 'offset' at offset sPtr.
		add(MemberVar("base", Value(named{core:unsafe:RawPtr}), this));
		add(MemberVar("offset", Value(named{core:Nat}), this));

		// Default ctors.
		// add(TypeDefaultCtor(this));
		add(TypeCopyCtor(this));
		add(TypeAssign(this));

		Value val(this, false);
		Value ref(this, true);
		Value int(named{Int});

		// Version of the default ctor that writes a zero to our struct, so that we can see the
		// initialization.
		addFn(Value(), "__init", [ref], defCtor());

		// Compare pointers.
		addFn(named{Bool}, "==", [ref, ref], named{ptr:pointerEq<unsafe:RawPtr, unsafe:RawPtr>});
		addFn(named{Bool}, "!=", [ref, ref], named{ptr:pointerNeq<unsafe:RawPtr, unsafe:RawPtr>});
		addFn(named{Bool}, "<", [ref, ref], named{ptr:pointerLt<unsafe:RawPtr, unsafe:RawPtr>});
		addFn(named{Bool}, ">", [ref, ref], named{ptr:pointerGt<unsafe:RawPtr, unsafe:RawPtr>});
		addFn(named{Bool}, "<=", [ref, ref], named{ptr:pointerLte<unsafe:RawPtr, unsafe:RawPtr>});
		addFn(named{Bool}, ">=", [ref, ref], named{ptr:pointerGte<unsafe:RawPtr, unsafe:RawPtr>});

		if (t = inside) {
			// Pointer arithmetic. All steps are scaled by the aligned size of the pointed-to type.
			Size sz = t.size.aligned;
			addFn(val, "+", [ref, int], ptrAdd(Offset(sz)));
			// 'p - n' is 'p + n' with a negated stride, just like "-=" and "--" below.
			addFn(val, "-", [ref, int], ptrAdd(-Offset(sz)));
			addFn(ref, "+=", [ref, int], ptrInc(Offset(sz)));
			addFn(ref, "-=", [ref, int], ptrInc(-Offset(sz)));
			addFn(ref, "++*", [ref], ptrPrefixInc(Offset(sz)));
			addFn(val, "*++", [ref], ptrPostfixInc(Offset(sz)));
			addFn(ref, "--*", [ref], ptrPrefixInc(-Offset(sz)));
			addFn(val, "*--", [ref], ptrPostfixInc(-Offset(sz)));
			addFn(int, "-", [ref, ref], ptrDiff(Offset(sz)));

			// Allocate arrays from Storm. Useful when implementing the standard library. These will
			// be marked as heap allocations.
			Function f(val, "allocArray", [Value(named{Nat})]);
			f.setCode(DynamicCode(allocArrayFn(t)));
			f.make(FnFlags:static);
			add(f);

			Function d(Value(), "deepCopy", [thisPtr(this), named{CloneEnv}]);
			d.setCode(DynamicCode(deepCopyFn(t)));
			add(d);
		}

		// TODO: Add suitable members!

		super:loadAll();
	}

	// Add a function that delegates to an existing function 'fn'.
	private void addFn(Value result, Str name, Value[] params, Function fn) {
		Function f(result, name, params);
		f.setCode(DelegatedCode(fn.ref));
		add(f);
	}

	// Add a function whose body is the generated listing 'l'.
	private void addFn(Value result, Str name, Value[] params, Listing l) {
		Function f(result, name, params);
		f.setCode(DynamicCode(l));
		add(f);
	}

	// Generate the default ctor. It calls RawPtr's __init on the 'base' field and then explicitly
	// writes 0 to the 'offset' field (at sPtr) so that it registers as an initialization.
	private Listing defCtor() : static {
		Listing l(true, ptrDesc);
		Var me = l.createParam(ptrDesc);
		l << prolog();
		l << fnParam(ptrDesc, me);
		l << fnCall(named{core:unsafe:RawPtr:__init<core:unsafe:RawPtr>}.ref, true);
		l << mov(ptrA, me);
		l << mov(intRel(ptrA, Offset(sPtr)), intConst(0));
		l << fnRet(ptrA);
		l;
	}

	// Generate += / -= operator: this.offset += delta * offset. Returns 'this'.
	// 'offset' is the stride, already negated for -=.
	private Listing ptrInc(Offset offset) : static {
		Listing l(true, ptrDesc);

		Var me = l.createParam(ptrDesc);
		Var delta = l.createParam(intDesc);

		l << prolog();
		l << mov(ptrA, me);
		l << mov(ebx, delta);
		l << mul(ebx, intConst(offset));
		l << add(intRel(ptrA, Offset(sPtr)), ebx);
		l << fnRet(ptrA);

		l;
	}

	// Generate + and - operator: returns a copy of 'this' whose offset is moved by delta * offset.
	// 'offset' is the stride, already negated for -. (Note: We don't currently support 3 + <ptr>)
	private Listing ptrAdd(Offset offset) {
		Listing l(true, Value(this).desc);

		Var me = l.createParam(ptrDesc);
		Var delta = l.createParam(intDesc);
		Var res = l.createVar(l.root, size);

		l << prolog();
		l << mov(ptrA, me);
		// Copy base and offset into the result.
		l << mov(ptrRel(res), ptrRel(ptrA));
		l << mov(intRel(res, Offset(sPtr)), intRel(ptrA, Offset(sPtr)));

		l << mov(ebx, delta);
		l << mul(ebx, intConst(offset));
		l << add(intRel(res, Offset(sPtr)), ebx);
		l << fnRet(res);

		l;
	}

	// Prefix ++ and --: adds the (possibly negated) stride to this.offset and returns 'this'.
	private Listing ptrPrefixInc(Offset offset) {
		Listing l(true, ptrDesc);

		Var me = l.createParam(ptrDesc);

		l << prolog();
		l << mov(ptrA, me);
		l << add(intRel(ptrA, Offset(sPtr)), intConst(offset));
		l << fnRet(ptrA);

		l;
	}

	// Postfix ++ and --: returns a copy of 'this' taken before adding the (possibly negated)
	// stride to this.offset.
	private Listing ptrPostfixInc(Offset offset) {
		Listing l(true, Value(this).desc);

		Var me = l.createParam(ptrDesc);
		Var res = l.createVar(l.root, size);

		l << prolog();
		l << mov(ptrA, me);

		// Make a copy.
		l << mov(ptrRel(res), ptrRel(ptrA));
		l << mov(intRel(res, Offset(sPtr)), intRel(ptrA, Offset(sPtr)));

		l << add(intRel(ptrA, Offset(sPtr)), intConst(offset));
		l << fnRet(res);

		l;
	}

	// Difference between two pointers: (me.offset - o.offset) / offset, using signed division.
	// It first calls ptr:assumeSameAlloc on both pointers.
	private Listing ptrDiff(Offset offset) : static {
		Listing l(true, intDesc);

		Var me = l.createParam(ptrDesc);
		Var o = l.createParam(ptrDesc);

		l << prolog();

		// Check if they are from the same allocation.
		l << fnParam(ptrDesc, me);
		l << fnParam(ptrDesc, o);
		l << fnCall(named{ptr:assumeSameAlloc<unsafe:RawPtr, unsafe:RawPtr>}.ref, false);

		l << mov(ptrA, me);
		l << mov(ptrB, o);
		l << mov(eax, intRel(ptrA, Offset(sPtr)));
		l << mov(ebx, intRel(ptrB, Offset(sPtr)));
		l << sub(eax, ebx);
		l << idiv(eax, intConst(offset));
		l << fnRet(eax);

		l;
	}

	// Generate the static 'allocArray(Nat count)'. It allocates 'count' elements of 'inside' with
	// BuiltIn:allocArray and returns a pointer whose offset (sPtr * 2) addresses the first element.
	// It also writes 'count | arrayAlloc | heapAlloc' into the allocation's second word.
	private Listing allocArrayFn(Type inside) {
		Listing l(false, this.typeDesc);

		Var res = l.createVar(l.root, this.size);
		Var param = l.createParam(intDesc);

		l << prolog();

		l << ucast(ptrA, param);
		l << fnParam(ptrDesc, inside.typeRef);
		l << fnParam(ptrDesc, ptrA);
		l << fnCall(ref(BuiltIn:allocArray), false, ptrDesc, ptrA);
		l << mov(ptrRel(res, Offset()), ptrA);
		l << mov(intRel(res, Offset(sPtr)), natConst(sPtr * 2));

		Nat mask = AllocFlags:arrayAlloc.v | AllocFlags:heapAlloc.v;
		l << or(param, natConst(mask));
		l << mov(intRel(ptrA, Offset(sPtr)), param);

		l << fnRet(res);

		l;
	}

	// Generate the 'deepCopy(CloneEnv)' member. It replaces this.base with a copy of the
	// allocation, reusing a copy already registered in the CloneEnv if one exists.
	private Listing deepCopyFn(Type inside) {
		Listing l(true, voidDesc);

		var pDesc = ptrDesc;

		Var thisParam = l.createParam(pDesc);
		Var envParam = l.createParam(pDesc);

		l << prolog();

		// Load the allocation ('base', at offset 0). If it is null there is nothing to copy.
		Var alloc = l.createVar(l.root, sPtr);
		l << mov(ptrA, thisParam);
		l << mov(alloc, ptrRel(ptrA));

		Label done = l.label();
		l << cmp(alloc, ptrConst(Offset()));
		l << jmp(done, CondFlag:ifEqual);

		// Was this allocation already copied in this CloneEnv?
		Var copy = l.createVar(l.root, sPtr);
		l << fnParam(pDesc, envParam);
		l << fnParam(pDesc, alloc);
		l << fnCall(ref(BuiltIn:cloneEnvGet), false, pDesc, copy);

		// Did we get something?
		l << cmp(copy, ptrConst(Offset()));
		l << jmp(done, CondFlag:ifNotEqual);

		// No, we need to clone things ourselves.

		// Read the size of the allocation (its first word) and allocate a copy!
		l << mov(ptrA, alloc);
		l << mov(ptrA, ptrRel(ptrA));
		l << fnParam(ptrDesc, inside.typeRef);
		l << fnParam(ptrDesc, ptrA);
		l << fnCall(ref(BuiltIn:allocArray), false, pDesc, copy);

		// Copy mask/filled.
		l << mov(ptrA, alloc);
		l << mov(ptrC, copy);
		l << mov(ptrRel(ptrC, Offset(sPtr)), ptrRel(ptrA, Offset(sPtr)));

		// Copy all elements. They start after the two header words (offset sPtr * 2).
		Label loopHead = l.label();
		Label loopTail = l.label();
		Var id = l.createVar(l.root, sPtr); // initialized to zero
		Var allocPos = l.createVar(l.root, sPtr);
		Var copyPos = l.createVar(l.root, sPtr);

		l << mov(ptrA, alloc);
		l << lea(allocPos, ptrRel(ptrA, Offset(sPtr * 2)));
		l << mov(ptrA, copy);
		l << lea(copyPos, ptrRel(ptrA, Offset(sPtr * 2)));

		l << loopHead;
		l << mov(ptrA, alloc);
		l << cmp(id, ptrRel(ptrA));
		l << jmp(loopTail, CondFlag:ifAboveEqual);

		if (Value(inside).isAsmType()) {
			// Primitive element: a single move of the right width.
			l << mov(ptrA, allocPos);
			l << mov(ptrC, copyPos);
			var size = Value(inside).size;
			l << mov(xRel(size, ptrC), xRel(size, ptrA));
		} else if (copyCtor = inside.copyCtor) {
			// NOTE(review): the deepCopy call below passes 'true' as fnCall's member flag, this
			// one passes 'false'. Confirm which is intended.
			l << fnParam(pDesc, copyPos);
			l << fnParam(pDesc, allocPos);
			l << fnCall(copyCtor.ref, false);
		} else {
			// No copy ctor: copy the raw bytes, pointer-sized chunks first, then single bytes.
			// Note: this is not entirely platform independent:
			Offset offset;
			Nat ptrSize = sPtr.current;
			Nat totalSize = inside.size.current;
			l << mov(ptrA, allocPos);
			l << mov(ptrC, copyPos);
			while (offset.current.nat + ptrSize <= totalSize) {
				l << mov(ptrRel(ptrC, offset), ptrRel(ptrA, offset));
				offset += sPtr;
			}
			while (offset.current.nat + 1 <= totalSize) {
				l << mov(byteRel(ptrC, offset), byteRel(ptrA, offset));
				offset += sByte;
			}
		}

		if (deepCopy = inside.deepCopyFn) {
			l << fnParam(pDesc, copyPos);
			l << fnParam(pDesc, envParam);
			l << fnCall(deepCopy.ref, true);
		}

		Size alignedSize = inside.size.aligned;
		l << add(allocPos, ptrConst(alignedSize));
		l << add(copyPos, ptrConst(alignedSize));
		l << add(id, ptrConst(1));
		l << jmp(loopHead);

		l << loopTail;
		// Save it to the CloneEnv.
		l << fnParam(pDesc, envParam);
		l << fnParam(pDesc, alloc);
		l << fnParam(pDesc, copy);
		l << fnCall(ref(BuiltIn:cloneEnvPut), false);

		// Store the updated allocation and we're done. When 'alloc' was null we jump straight
		// here and 'copy' has not been written, so the stored value stays as it was zero-initialized.
		l << done;
		l << mov(ptrC, thisParam);
		l << mov(ptrRel(ptrC), copy);

		l << fnRet();
		l;
	}

	// Nicer to string in error messages etc.
	// 'rvalRef' is checked before 'isRef': rvalue references may also carry 'isRef' (see
	// isCppRRef), and must still print as '&&' so they do not look like (or share an identifier
	// with) lvalue references.
	void toS(StrBuf to) : override {
		if (isConst)
			to << "const ";
		to << params[0];
		if (rvalRef)
			to << "&&";
		else if (isRef)
			to << "&";
		else
			to << "*";
	}

	// Also for the identifier.
	Str identifier() : override {
		toS();
	}
}

// Wrap things inside a pointer or a reference.
Value wrapPtr(Value val) {
	// Look up 'Ptr' in the ptr package, parameterized by 'val' with its reference flag cleared.
	Value plain = val.asRef(false);
	unless (found = (named{ptr}).find(SimplePart("Ptr", [plain]), Scope()) as Type)
		throw InternalError("Could not find the pointer type for ${val}");
	Value(found);
}

Value wrapConstPtr(Value val) {
	// Look up 'ConstPtr' in the ptr package, parameterized by 'val' with its reference flag cleared.
	Value plain = val.asRef(false);
	unless (found = (named{ptr}).find(SimplePart("ConstPtr", [plain]), Scope()) as Type)
		throw InternalError("Could not find the pointer type for ${val}");
	Value(found);
}

Value wrapRef(Value val) {
	// Look up 'Ref' in the ptr package, parameterized by 'val' with its reference flag cleared.
	Value plain = val.asRef(false);
	unless (found = (named{ptr}).find(SimplePart("Ref", [plain]), Scope()) as Type)
		throw InternalError("Could not find the pointer type for ${val}");
	Value(found);
}

Value wrapConstRef(Value val) {
	// Look up 'ConstRef' in the ptr package, parameterized by 'val' with its reference flag cleared.
	Value plain = val.asRef(false);
	unless (found = (named{ptr}).find(SimplePart("ConstRef", [plain]), Scope()) as Type)
		throw InternalError("Could not find the pointer type for ${val}");
	Value(found);
}

Value wrapRRef(Value val) {
	// Look up 'RRef' in the ptr package, parameterized by 'val' with its reference flag cleared.
	Value plain = val.asRef(false);
	unless (found = (named{ptr}).find(SimplePart("RRef", [plain]), Scope()) as Type)
		throw InternalError("Could not find the pointer type for ${val}");
	Value(found);
}

Value wrapConstRRef(Value val) {
	// Look up 'ConstRRef' in the ptr package, parameterized by 'val' with its reference flag cleared.
	Value plain = val.asRef(false);
	unless (found = (named{ptr}).find(SimplePart("ConstRRef", [plain]), Scope()) as Type)
		throw InternalError("Could not find the pointer type for ${val}");
	Value(found);
}

// Unwrap pointers and references.
Value unwrapPtr(Value val) {
	// Strip a PtrType whose 'isRef' flag is clear; anything else is returned unchanged.
	unless (ptrType = val.type as PtrType)
		return val;
	if (ptrType.isRef)
		return val;
	ptrType.params[0];
}

Value unwrapRef(Value val) {
	// Strip a PtrType whose 'isRef' flag is set; anything else is returned unchanged.
	unless (ptrType = val.type as PtrType)
		return val;
	if (ptrType.isRef)
		return ptrType.params[0];
	val;
}

Value unwrapPtrOrRef(Value val) {
	// Strip any PtrType wrapper, regardless of its flags.
	unless (ptrType = val.type as PtrType)
		return val;
	ptrType.params[0];
}

// Is it a ptr or ref?
Bool isCppPtr(Value val) {
	// Non-PtrType values are never pointers.
	unless (ptrType = val.type as PtrType)
		return false;
	ptrType.isPtr;
}

Bool isCppRef(Value val) {
	// Non-PtrType values are never references; otherwise report the 'isRef' flag.
	unless (ptrType = val.type as PtrType)
		return false;
	ptrType.isRef;
}

// Unwrap a reference. Returns 'null' if not a reference.
Type? isCppRef(Type t) {
	// Returns the referenced type when 't' is a PtrType with 'isRef' set, otherwise null.
	unless (ptrType = t as PtrType)
		return null;
	if (!ptrType.isRef)
		return null;
	ptrType.inside();
}

Type? isCppRRef(Type t) {
	// Returns the referenced type only when both 'isRef' and 'rvalRef' are set, otherwise null.
	unless (ptrType = t as PtrType)
		return null;
	if (!ptrType.isRef | !ptrType.rvalRef)
		return null;
	ptrType.inside();
}

Bool isCppConst(Type t) {
	// Only PtrTypes can carry the 'const' flag.
	unless (ptrType = t as PtrType)
		return false;
	ptrType.isConst;
}
