#include <gu/out.h>
#include <gu/seq.h>
#include <gu/fun.h>
#include <gu/str.h>
#include <gu/assert.h>
#include <stdlib.h>
#ifdef __MINGW32__
#include <malloc.h>
#endif

/* A length-prefixed contiguous array of elements.
 *
 * LEN counts elements, not bytes.  The element size is not stored here:
 * functions that index into DATA receive it from their caller (or from the
 * owning GuBuf). */
struct GuSeq {
	size_t len;
	uint8_t data[];	/* C99 flexible array member (was GNU `data[0]`) */
};

/* A growable array of fixed-size elements. */
struct GuBuf {
	GuSeq* seq;		/* storage: the shared gu_empty_seq() until first growth,
				 * then a gu_mem_buf_alloc'd block; seq->len = elements in use */
	size_t elem_size;	/* size in bytes of one element */
	size_t avail_len;	/* capacity in elements; 0 while seq is the shared empty seq */
	GuFinalizer fin;	/* pool finalizer (gu_buf_fini) that frees seq's heap block */
};

/* Returns how many elements BUF currently holds. */
size_t
gu_buf_length(GuBuf* buf)
{
	GuSeq* contents = buf->seq;
	return contents->len;
}

/* Returns BUF's capacity in elements, i.e. how many fit before the
 * backing storage has to grow. */
size_t
gu_buf_avail(GuBuf* buf)
{
	size_t capacity = buf->avail_len;
	return capacity;
}

/* Pool finalizer for a GuBuf: releases the heap block backing the buffer.
 * A zero capacity means the buffer never grew and still points at the
 * shared static empty sequence, which must not be freed. */
static void
gu_buf_fini(GuFinalizer* fin)
{
	GuBuf* owner = gu_container(fin, GuBuf, fin);
	if (owner->avail_len == 0)
		return;
	gu_mem_buf_free(owner->seq);
}

/* Creates an empty buffer of ELEM_SIZE-byte elements in POOL.  The buffer
 * starts on the shared empty sequence with zero capacity, and its heap
 * storage (if any is ever allocated) is released when POOL is finalized. */
GuBuf*
gu_make_buf(size_t elem_size, GuPool* pool)
{
	GuBuf* b = gu_new(GuBuf, pool);
	b->elem_size = elem_size;
	b->avail_len = 0;
	b->seq = gu_empty_seq();
	// Register the finalizer last, once every field it reads is set.
	b->fin.fn = gu_buf_fini;
	gu_pool_finally(pool, &b->fin);
	return b;
}

/* Returns the element count recorded in SEQ's header. */
size_t
gu_seq_length(GuSeq* seq)
{
	size_t n = seq->len;
	return n;
}

/* Returns a pointer to the first element stored in SEQ. */
void*
gu_seq_data(GuSeq* seq)
{
	return &seq->data[0];
}

static GuSeq gu_empty_seq_ = {0};

GuSeq*
gu_empty_seq() {
	return &gu_empty_seq_;
}

/* Allocates from POOL a sequence of exactly LENGTH elements of ELEM_SIZE
 * bytes each.  The element bytes are left uninitialized. */
GuSeq*
gu_make_seq(size_t elem_size, size_t length, GuPool* pool)
{
	size_t payload = elem_size * length;
	GuSeq* result = gu_malloc(pool, sizeof(GuSeq) + payload);
	result->len = length;
	return result;
}

/* Heap-allocates a sequence with room for at least LENGTH elements.
 *
 * LENGTH == 0 yields the shared empty sequence.  Otherwise len is set to
 * however many whole elements fit in the block gu_mem_buf_alloc actually
 * returned, which may exceed LENGTH. */
GuSeq*
gu_alloc_seq_(size_t elem_size, size_t length)
{
	if (length > 0) {
		size_t got;
		GuSeq* fresh = gu_mem_buf_alloc(sizeof(GuSeq) + elem_size * length, &got);
		fresh->len = (got - sizeof(GuSeq)) / elem_size;
		return fresh;
	}
	return gu_empty_seq();
}

/* Resizes SEQ to hold at least LENGTH elements and returns the (possibly
 * moved) sequence.  A NULL or shared-empty SEQ gets a fresh block instead
 * of being reallocated.  As in gu_alloc_seq_, len becomes the number of
 * whole elements that fit in the block actually obtained. */
GuSeq*
gu_realloc_seq_(GuSeq* seq, size_t elem_size, size_t length)
{
	size_t want = sizeof(GuSeq) + elem_size * length;
	size_t got;
	GuSeq* result;

	if (seq != NULL && seq != gu_empty_seq())
		result = gu_mem_buf_realloc(seq, want, &got);
	else
		result = gu_mem_buf_alloc(want, &got);

	result->len = (got - sizeof(GuSeq)) / elem_size;
	return result;
}

/* Frees a sequence obtained from gu_alloc_seq_/gu_realloc_seq_.  NULL and
 * the shared empty sequence are accepted and ignored. */
void
gu_seq_free(GuSeq* seq)
{
	if (seq != NULL && seq != gu_empty_seq())
		gu_mem_buf_free(seq);
}

/* Ensures BUF can hold at least REQ_LEN elements, growing the backing
 * storage if needed.  Existing elements and the current length survive a
 * reallocation.  The first growth leaves the shared empty sequence for a
 * freshly allocated block with length 0. */
static void
gu_buf_require(GuBuf* buf, size_t req_len)
{
	if (req_len > buf->avail_len) {
		size_t want = sizeof(GuSeq) + buf->elem_size * req_len;
		size_t got;
		GuSeq* old = buf->seq;

		if (old != NULL && old != gu_empty_seq()) {
			buf->seq = gu_mem_buf_realloc(old, want, &got);
		} else {
			buf->seq = gu_mem_buf_alloc(want, &got);
			buf->seq->len = 0;
		}

		// Capacity is whatever the allocator really handed back.
		buf->avail_len = (got - sizeof(GuSeq)) / buf->elem_size;
	}
}

/* Returns a pointer to BUF's first element.  The pointer is invalidated by
 * any operation that grows the buffer, since growth may move the storage. */
void*
gu_buf_data(GuBuf* buf)
{
	return buf->seq->data;
}

/* Returns the GuSeq currently backing BUF.  BUF keeps ownership: the
 * pointer is replaced when the buffer grows, and (once heap-allocated) the
 * block is freed by BUF's pool finalizer. */
GuSeq*
gu_buf_data_seq(GuBuf* buf)
{
	GuSeq* backing = buf->seq;
	return backing;
}

/* Appends N_ELEMS uninitialized elements to BUF and returns a pointer to
 * the first of them. */
void*
gu_buf_extend_n(GuBuf* buf, size_t n_elems)
{
	size_t old_len = gu_buf_length(buf);
	gu_buf_require(buf, old_len + n_elems);
	GuSeq* s = buf->seq;	// re-read: gu_buf_require may have moved it
	s->len = old_len + n_elems;
	return s->data + buf->elem_size * old_len;
}

/* Appends one uninitialized element to BUF and returns a pointer to it. */
void*
gu_buf_extend(GuBuf* buf)
{
	void* slot = gu_buf_extend_n(buf, 1);
	return slot;
}

/* Appends N_ELEMS elements copied from DATA to the end of BUF. */
void
gu_buf_push_n(GuBuf* buf, const void* data, size_t n_elems)
{
	size_t n_bytes = buf->elem_size * n_elems;
	memcpy(gu_buf_extend_n(buf, n_elems), data, n_bytes);
}

/* Drops the last N_ELEMS elements of BUF (gu_require rejects more than the
 * current length).  Returns a pointer to the first dropped element.  Its
 * bytes remain in place until the slots are reused. */
const void*
gu_buf_trim_n(GuBuf* buf, size_t n_elems)
{
	size_t len = gu_buf_length(buf);
	gu_require(n_elems <= len);
	len -= n_elems;
	buf->seq->len = len;
	return buf->seq->data + buf->elem_size * len;
}

/* Drops BUF's last element and returns a pointer to its (still intact)
 * bytes. */
const void*
gu_buf_trim(GuBuf* buf)
{
	const void* removed = gu_buf_trim_n(buf, 1);
	return removed;
}

/* Empties BUF without releasing its storage; the capacity is kept. */
void
gu_buf_flush(GuBuf* buf)
{
	GuSeq* s = buf->seq;
	s->len = 0;
}

/* Removes the last N_ELEMS elements of BUF, copying them in order to
 * DATA_OUT. */
void
gu_buf_pop_n(GuBuf* buf, size_t n_elems, void* data_out)
{
	size_t n_bytes = n_elems * buf->elem_size;
	memcpy(data_out, gu_buf_trim_n(buf, n_elems), n_bytes);
}

/* Returns a POOL-allocated copy of BUF's current contents as a GuSeq of
 * exactly gu_buf_length(BUF) elements.  BUF itself is left unchanged. */
GuSeq*
gu_buf_freeze(GuBuf* buf, GuPool* pool)
{
	size_t n = gu_buf_length(buf);
	size_t esz = buf->elem_size;
	GuSeq* frozen = gu_make_seq(esz, n, pool);
	memcpy(gu_seq_data(frozen), gu_buf_data(buf), esz * n);
	return frozen;
}

/* Opens a one-element gap at INDEX, shifting elements INDEX..len-1 up by
 * one, and returns a pointer to the (uninitialized) slot for the caller to
 * fill.
 *
 * INDEX must be in [0, length]; INDEX == length appends.  A larger INDEX
 * would make (len - index) wrap around and memmove far past the buffer, so
 * it is rejected with gu_require. */
void*
gu_buf_insert(GuBuf* buf, size_t index)
{
	size_t len = buf->seq->len;
	gu_require(index <= len);
	gu_buf_require(buf, len + 1);

	// Compute the slot only after gu_buf_require, which may move storage.
	uint8_t* target = &buf->seq->data[buf->elem_size * index];
	memmove(target + buf->elem_size, target, (len - index) * buf->elem_size);

	buf->seq->len = len + 1;
	return target;
}

/* Recursive in-place quicksort of the elements at indices LEFT..RIGHT
 * (inclusive) of BUF, in ascending order according to ORDER->compare.
 *
 * The element at LEFT is copied into an alloca'd pivot buffer, leaving a
 * "hole" at LEFT.  The partition loop alternately scans from the right for
 * an element ordered before the pivot and moves it into the hole, then scans
 * from the left for one ordered after the pivot and moves it into the new
 * hole.  When the scans meet, the pivot is written into the final hole and
 * both sides are sorted recursively.
 *
 * Requires LEFT <= RIGHT to be valid indices: the pivot copy at LEFT happens
 * unconditionally.
 * NOTE(review): with the leftmost element as pivot, already-sorted input
 * recurses once per element (and allocas a pivot per level), so stack depth
 * can grow linearly with the input size. */
static void
gu_quick_sort(GuBuf *buf, GuOrder *order, int left, int right)
{
	int l_hold = left;
	int r_hold = right;

	// Pivot = leftmost element; its slot becomes the initial hole.
	void* pivot = alloca(buf->elem_size);
	memcpy(pivot,
	       &buf->seq->data[buf->elem_size * left],
	       buf->elem_size);
	while (left < right) {

		// Skip elements on the right that belong after (or with) the pivot.
		while ((order->compare(order, &buf->seq->data[buf->elem_size * right], pivot) >= 0) && (left < right))
			right--;

		// Move the out-of-place element into the hole at LEFT; RIGHT becomes the hole.
		if (left != right) {
			memcpy(&buf->seq->data[buf->elem_size * left],
			       &buf->seq->data[buf->elem_size * right],
			       buf->elem_size);
			left++;
		}

		// Skip elements on the left that belong before (or with) the pivot.
		while ((order->compare(order, &buf->seq->data[buf->elem_size * left], pivot) <= 0) && (left < right))
			left++;

		// Move the out-of-place element into the hole at RIGHT; LEFT becomes the hole.
		if (left != right) {
			memcpy(&buf->seq->data[buf->elem_size * right],
			       &buf->seq->data[buf->elem_size * left],
			       buf->elem_size);
			right--;
		}
	}
	
	// Scans met: drop the pivot into the last hole, its final position.
	memcpy(&buf->seq->data[buf->elem_size * left],
	       pivot,
           buf->elem_size);
	int index = left;
	left  = l_hold;
	right = r_hold;

	// Sort the partitions on either side of the pivot.
	if (left < index)
		gu_quick_sort(buf, order, left, index-1);

	if (right > index)
		gu_quick_sort(buf, order, index+1, right);
}

/* Sorts BUF in place, ascending according to ORDER->compare.
 *
 * Buffers with fewer than two elements are already sorted.  Returning early
 * matters for the empty case: gu_quick_sort unconditionally copies the
 * element at LEFT into and back out of its pivot buffer.  An empty buffer may
 * be backed by the shared zero-length gu_empty_seq(), so sorting it would
 * read and write out of bounds.
 *
 * NOTE(review): gu_quick_sort takes int indices, so buffers longer than
 * INT_MAX elements are not supported. */
void
gu_buf_sort(GuBuf *buf, GuOrder *order)
{
	size_t len = gu_buf_length(buf);
	if (len < 2)
		return;
	gu_quick_sort(buf, order, 0, (int) (len - 1));
}

/* Binary search for KEY in SEQ, whose ELEM_SIZE-byte elements must be
 * sorted ascending by ORDER.  ORDER->compare is called as
 * compare(order, key, element).  Returns a pointer to a matching element,
 * or NULL if there is none (including for an empty SEQ).
 *
 * Indices are size_t over the half-open range [lo, hi), so lengths beyond
 * INT_MAX neither truncate nor overflow the midpoint computation.  The
 * midpoint lo + (hi - lo - 1) / 2 equals the classic closed-interval
 * (i + j) / 2 with j = hi - 1, so the probe sequence (and therefore which of
 * several equal elements is returned) is unchanged. */
void*
gu_seq_binsearch_(GuSeq *seq, GuOrder *order, size_t elem_size, void *key)
{
	size_t lo = 0;
	size_t hi = seq->len;

	while (lo < hi) {
		size_t mid = lo + (hi - lo - 1) / 2;
		uint8_t* elem_p = &seq->data[elem_size * mid];
		int cmp = order->compare(order, key, elem_p);

		if (cmp < 0) {
			hi = mid;
		} else if (cmp > 0) {
			lo = mid + 1;
		} else {
			return elem_p;
		}
	}

	return NULL;
}

/* Binary search for KEY in SEQ, whose ELEM_SIZE-byte elements must be
 * sorted ascending by ORDER.  ORDER->compare is called as
 * compare(order, key, element).
 *
 * On a match, stores the element's index in *PINDEX and returns true.
 * Otherwise, stores the index of the last element ordered before KEY and
 * returns false; that index is (size_t)-1 when KEY precedes every element,
 * including for an empty SEQ.
 *
 * The previous closed-interval size_t version computed len - 1 and k - 1.
 * Both wrap to SIZE_MAX when the sequence is empty or the key precedes
 * element 0, after which the loop kept probing far out of bounds.  The
 * half-open [lo, hi) form cannot wrap.  Its midpoint matches the old
 * (i + j) / 2 probe order, and the not-found index hi - 1 == lo - 1 equals
 * the old final j. */
bool
gu_seq_binsearch_index_(GuSeq *seq, GuOrder *order, size_t elem_size,
                        void *key, size_t *pindex)
{
	size_t lo = 0;
	size_t hi = seq->len;

	while (lo < hi) {
		size_t mid = lo + (hi - lo - 1) / 2;
		uint8_t* elem_p = &seq->data[elem_size * mid];
		int cmp = order->compare(order, key, elem_p);

		if (cmp < 0) {
			hi = mid;
		} else if (cmp > 0) {
			lo = mid + 1;
		} else {
			*pindex = mid;
			return true;
		}
	}

	// Unsigned wrap-around is well defined: lo == 0 yields (size_t)-1.
	*pindex = lo - 1;
	return false;
}

/* Places VALUE into the heap stored in BUF, starting from the hole at POS
 * and moving towards the root (never above STARTPOS).  While VALUE compares
 * below the hole's parent, the parent is moved down into the hole.  VALUE is
 * then copied into the final hole. */
static void
gu_heap_siftdown(GuBuf *buf, GuOrder *order,
                 const void *value, int startpos, int pos)
{
	size_t esz = buf->elem_size;
	uint8_t *base = buf->seq->data;
	int hole = pos;

	while (hole > startpos) {
		int up = (hole - 1) / 2;
		uint8_t *parent_p = base + esz * up;
		if (order->compare(order, value, parent_p) >= 0)
			break;
		memcpy(base + esz * hole, parent_p, esz);
		hole = up;
	}

	memcpy(base + esz * hole, value, esz);
}

/* Re-establishes the heap property below POS after its element has been
 * replaced by VALUE.  The hole at POS is first walked down to a leaf, each
 * time promoting the child that sorts first (the right child wins ties).
 * gu_heap_siftdown then moves VALUE back up from that leaf to its proper
 * place, never above POS. */
static void
gu_heap_siftup(GuBuf *buf, GuOrder *order,
               const void *value, int pos)
{
	size_t esz = buf->elem_size;
	uint8_t *base = buf->seq->data;
	int root = pos;
	int limit = gu_buf_length(buf);

	for (int child = 2 * pos + 1; child < limit; child = 2 * pos + 1) {
		int sibling = child + 1;
		if (sibling < limit &&
		    order->compare(order, base + esz * child, base + esz * sibling) >= 0)
			child = sibling;
		memcpy(base + esz * pos, base + esz * child, esz);
		pos = child;
	}

	gu_heap_siftdown(buf, order, value, root, pos);
}

/* Pushes a copy of VALUE onto the heap stored in BUF: appends a slot, then
 * sifts VALUE from that slot towards the root. */
void
gu_buf_heap_push(GuBuf *buf, GuOrder *order, void *value)
{
	gu_buf_extend(buf);
	int last = gu_buf_length(buf) - 1;
	gu_heap_siftdown(buf, order, value, 0, last);
}

/* Removes the top element (index 0) of the heap in BUF and copies it to
 * DATA_OUT.  The detached last element is re-inserted from the root. */
void
gu_buf_heap_pop(GuBuf *buf, GuOrder *order, void* data_out)
{
	// Detach the last element first; gu_buf_trim's gu_require rejects an
	// empty buffer.
	const void* last = gu_buf_trim(buf);

	if (gu_buf_length(buf) == 0) {
		// It was the only element, hence also the top.
		memcpy(data_out, last, buf->elem_size);
		return;
	}

	memcpy(data_out, buf->seq->data, buf->elem_size);
	gu_heap_siftup(buf, order, last, 0);
}

/* Copies the top element of the non-empty heap in BUF to DATA_OUT and
 * replaces it with VALUE, restoring the heap property. */
void
gu_buf_heap_replace(GuBuf *buf, GuOrder *order, void *value, void *data_out)
{
	gu_require(gu_buf_length(buf) != 0);

	size_t esz = buf->elem_size;
	memcpy(data_out, buf->seq->data, esz);
	gu_heap_siftup(buf, order, value, 0);
}

/* Rearranges BUF in place into a binary heap with respect to ORDER.  The
 * sift helpers move elements that compare lower towards the root, so the
 * ORDER-smallest element ends up at index 0.
 *
 * gu_heap_siftup(pos) assumes both subtrees of POS are already heaps, so
 * the internal nodes must be visited bottom-up: from the last one
 * (len/2 - 1) down to the root.  Visiting them root-first, as before, does
 * not build a heap; for example [3,2,1,0] became [1,0,3,2]. */
void
gu_buf_heapify(GuBuf *buf, GuOrder *order)
{
	size_t middle = gu_buf_length(buf) / 2;
	void *value = alloca(buf->elem_size);

	for (size_t i = middle; i > 0; i--) {
		size_t pos = i - 1;
		// Copy out first: siftup overwrites data[pos] while moving the hole.
		memcpy(value, &buf->seq->data[buf->elem_size * pos], buf->elem_size);
		gu_heap_siftup(buf, order, value, (int) pos);
	}
}

typedef struct GuBufOut GuBufOut;
/* Output-stream adapter that appends everything written to it to a GuBuf.
 * The GuOutStream is embedded by value so the stream callbacks can recover
 * the enclosing GuBufOut with gu_container. */
struct GuBufOut
{
	GuOutStream stream;
	GuBuf* buf;	/* destination buffer, supplied by the caller of gu_buf_out */
};

/* GuOutStream output callback: appends the SZ bytes at SRC to the buffer as
 * SZ / elem_size elements.  SZ must be a whole number of elements.
 *
 * Returns the number of bytes consumed (SZ), not the element count.  The two
 * differ whenever elem_size > 1.
 * NOTE(review): assumes the GuOutStream output contract counts bytes, as the
 * sz parameters of the begin_buf/end_buf callbacks in this file do; confirm
 * against gu/out.h. */
static size_t
gu_buf_out_output(GuOutStream* stream, const uint8_t* src, size_t sz,
		  GuExn* err)
{
	(void) err;
	GuBufOut* bout = gu_container(stream, GuBufOut, stream);
	GuBuf* buf = bout->buf;
	gu_assert(sz % buf->elem_size == 0);
	gu_buf_push_n(buf, src, sz / buf->elem_size);
	return sz;
}

/* GuOutStream begin_buf callback: exposes the buffer's spare capacity as a
 * writable byte window.  At least REQ bytes are made available (rounded up
 * to whole elements).  The window size in bytes is stored in *SZ_OUT and a
 * pointer to its start is returned.  Written bytes become part of the buffer
 * only when gu_buf_outbuf_end commits them. */
static uint8_t*
gu_buf_outbuf_begin(GuOutStream* stream, size_t req, size_t* sz_out, GuExn* err)
{
	(void) err;
	GuBufOut* bout = gu_container(stream, GuBufOut, stream);
	GuBuf* buf = bout->buf;
	size_t esz = buf->elem_size;
	size_t len = gu_buf_length(buf);

	// Always ask for at least one element, so the window is never empty
	// (which the assertion below requires) even if REQ is 0 and the buffer
	// is currently full.
	size_t req_elems = (req + esz - 1) / esz;
	if (req_elems == 0)
		req_elems = 1;
	gu_buf_require(buf, len + req_elems);

	size_t avail = buf->avail_len;
	gu_assert(len < avail);
	*sz_out = esz * (avail - len);
	return &buf->seq->data[len * esz];
}

/* GuOutStream end_buf callback: commits SZ bytes that the writer placed in
 * the window returned by gu_buf_outbuf_begin.  SZ must be a whole number of
 * elements and must fit in the spare capacity (avail_len - len elements). */
static void
gu_buf_outbuf_end(GuOutStream* stream, size_t sz, GuExn* err)
{
	(void) err;
	GuBufOut* bout = gu_container(stream, GuBufOut, stream);
	GuBuf* buf = bout->buf;
	size_t len = gu_buf_length(buf);
	size_t elem_size = buf->elem_size;
	gu_require(sz % elem_size == 0);
	// Spare capacity is avail_len - len; computing len - avail_len instead
	// wraps around in size_t.  A writer may fill the whole window, so
	// sz equal to the spare capacity is allowed.
	gu_require(sz <= elem_size * (buf->avail_len - len));
	buf->seq->len = len + (sz / elem_size);
}

/* Returns a GuOut, allocated in POOL, whose output is appended to BUF.
 * BUF is not copied; it must outlive the returned stream.  The stream has no
 * flush callback. */
GuOut*
gu_buf_out(GuBuf* buf, GuPool* pool)
{
	GuBufOut* bout = gu_new(GuBufOut, pool);
	bout->buf = buf;

	GuOutStream* s = &bout->stream;
	s->output = gu_buf_out_output;
	s->begin_buf = gu_buf_outbuf_begin;
	s->end_buf = gu_buf_outbuf_end;
	s->flush = NULL;

	return gu_new_out(s, pool);
}

#include <gu/type.h>

/* Kind descriptors (from gu/type.h) for the two types implemented here:
 * GuSeq is declared under the GuOpaque kind, and GuBuf as abstract.
 * NOTE(review): the exact meaning of the kind hierarchy is defined by
 * GU_DEFINE_KIND in gu/type.h; confirm there. */
GU_DEFINE_KIND(GuSeq, GuOpaque);
GU_DEFINE_KIND(GuBuf, abstract);
