//go:build postgres
// +build postgres

package main

/*
#cgo pkg-config: libpq
#include <stdlib.h>
#include <libpq-fe.h>
*/
import "C"
import (
	"fmt"
	"strconv"
	"sync"
	"unsafe"

	"jobarranger2/src/libs/golibs/database"
)

// psqlDatabase is the PostgreSQL (libpq) implementation of database.Database
// returned by NewDB. It keeps the configuration used to open connections and
// the pool of connections opened up front by NewDB.
type psqlDatabase struct {
	// config is used both for the pooled connections and for ad-hoc
	// connections opened by GetConn when the pool is empty.
	config   *database.DBConfig
	// connPool holds config.MaxConCount() connections created by NewDB.
	connPool []database.DBConnection
}

// NewDB creates a PostgreSQL-backed database.Database and eagerly opens
// config.MaxConCount() connections for its pool.
//
// If opening any connection fails, every connection opened so far is closed
// before the error is returned, so a failed NewDB leaks no libpq connections.
// database.ErrDBConfigNil is returned when config is nil.
func NewDB(config *database.DBConfig) (database.Database, error) {
	if config == nil {
		return nil, database.ErrDBConfigNil
	}

	var conns []database.DBConnection

	for i := 0; i < config.MaxConCount(); i++ {
		conn := &psqlDBConnection{}

		if err := conn.Reconnect(config); err != nil {
			// Release the connections that were already established.
			for _, opened := range conns {
				opened.Close()
			}
			return nil, err
		}

		conns = append(conns, conn)
	}

	return &psqlDatabase{
		config:   config,
		connPool: conns,
	}, nil
}

// Close closes every connection held in the pool.
func (db psqlDatabase) Close() {
	for _, pooled := range db.connPool {
		pooled.Close()
	}
}

// IsConnClosed reports whether the pooled connection at index has been
// closed. It returns false when the pool is empty or index is outside
// [0, GetPoolSize()).
func (db psqlDatabase) IsConnClosed(index int) bool {
	if db.GetPoolSize() <= 0 {
		return false
	}

	// Out-of-range indices would panic on the slice access below.
	if index < 0 || index >= len(db.connPool) {
		return false
	}

	return db.connPool[index].IsClosed()
}

// ConnReconnect re-opens the pooled connection at index using the database's
// configuration.
//
// Returns database.ErrNoDBConn when the pool is empty and
// database.ErrIndexOutOfBound when index is outside [0, GetPoolSize()).
// Otherwise it returns the connection's Reconnect error, if any.
func (db psqlDatabase) ConnReconnect(index int) error {
	if db.GetPoolSize() <= 0 {
		return database.ErrNoDBConn
	}

	// Out-of-range indices would panic on the slice access below.
	if index < 0 || index >= len(db.connPool) {
		return database.ErrIndexOutOfBound
	}

	return db.connPool[index].Reconnect(db.config)
}

// GetPoolSize returns the number of connections held in the pool.
func (db psqlDatabase) GetPoolSize() (size int) {
	size = len(db.connPool)
	return size
}

// GetDBConfig returns the configuration this database was created with.
func (db psqlDatabase) GetDBConfig() (cfg *database.DBConfig) {
	cfg = db.config
	return cfg
}

// GetConnFromPool starts a session on the pooled connection at index and
// returns it. The connection stays locked by StartSession until the session
// is ended with EndSession.
//
// Returns database.ErrNoDBConn when the pool is empty and
// database.ErrIndexOutOfBound when index is outside [0, GetPoolSize()).
// Otherwise it returns StartSession's result, e.g. database.ErrDBConnLocked
// when another session already holds the connection.
func (db psqlDatabase) GetConnFromPool(index int) (database.DBConnection, error) {
	if db.GetPoolSize() <= 0 {
		return nil, database.ErrNoDBConn
	}

	// Out-of-range indices would panic on the slice access below.
	if index < 0 || index >= len(db.connPool) {
		return nil, database.ErrIndexOutOfBound
	}

	return db.connPool[index].StartSession()
}

// GetConn returns a connection that is ready for queries.
//
// When the pool is non-empty the connection comes from
// database.GetConnFromPool. Otherwise a new, unpooled connection is opened
// with the database's configuration.
//
// Either way, the returned connection already has a session started. An
// unpooled connection that cannot start a session is closed before the
// error is returned.
func (db *psqlDatabase) GetConn() (database.DBConnection, error) {
	if len(db.connPool) > 0 {
		return database.GetConnFromPool(db)
	}

	conn := &psqlDBConnection{}
	if err := conn.Reconnect(db.config); err != nil {
		return nil, err
	}

	// Mirror the pool path (GetConnFromPool calls StartSession): without a
	// session, execute/Select reject every call with ErrNoDBSession.
	session, err := conn.StartSession()
	if err != nil {
		conn.Close()
		return nil, err
	}

	return session, nil
}

// psqlDBConnection wraps a single libpq connection (PGconn) together with the
// session and transaction bookkeeping used by its methods.
type psqlDBConnection struct {
	// conn is nil until Reconnect succeeds and again after Close.
	conn            *C.PGconn
	// isInTransaction is set by Begin and cleared by Commit/Rollback.
	isInTransaction bool
	// dbErrCode is the SQLSTATE of the last failed statement.
	dbErrCode       string
	// dbErrMessage is libpq's message for the last failed statement or
	// failed connection attempt.
	dbErrMessage    string
	// dbResultStatus is the PQresultStatus of the last failed statement.
	dbResultStatus  int
	// hasSession is true between StartSession and EndSession; execute and
	// Select refuse to run without it.
	hasSession      bool
	// mu is acquired by StartSession (TryLock) and released by EndSession.
	mu              sync.Mutex
}

// IsClosed reports whether the connection has no underlying libpq handle,
// i.e. it was never opened or has been closed with Close.
func (dbConn *psqlDBConnection) IsClosed() (closed bool) {
	closed = dbConn.conn == nil
	return closed
}

// StartSession claims the connection for exclusive use and enables
// execute/Select on it. It does not block: if another session already holds
// the connection, database.ErrDBConnLocked is returned immediately.
func (dbConn *psqlDBConnection) StartSession() (database.DBConnection, error) {
	if acquired := dbConn.mu.TryLock(); acquired {
		dbConn.hasSession = true
		return dbConn, nil
	}

	return nil, database.ErrDBConnLocked
}

// EndSession finishes the session started by StartSession. It clears the
// session and transaction flags and releases the connection mutex so the
// connection can be handed out again.
//
// The libpq connection is reset with PQreset (close and reopen with the same
// parameters), which discards any server-side transaction or session state
// the caller left behind.
//
// Returns database.ErrDBConnNil if there is no underlying connection. If the
// reset could not re-establish the connection, it returns an error carrying
// libpq's message, which is also stored in dbErrMessage.
func (dbConn *psqlDBConnection) EndSession() error {
	// TryLock ensures the mutex is held before the deferred Unlock, so Unlock
	// never panics on an already-unlocked mutex.
	dbConn.mu.TryLock()
	defer func() {
		dbConn.hasSession = false // prevents calling db function after EndSession
		// PQreset ends any open transaction on the server; clear the local flag
		// so the next session can Begin again and Execute is not run outside a
		// real transaction.
		dbConn.isInTransaction = false
		dbConn.mu.Unlock()
	}()

	if dbConn.conn == nil {
		return database.ErrDBConnNil
	}

	C.PQreset(dbConn.conn)

	if C.PQstatus(dbConn.conn) != C.CONNECTION_OK {
		dbConn.dbErrMessage = C.GoString(C.PQerrorMessage(dbConn.conn))
		return fmt.Errorf("DB reset failed, psql_err_msg: %s", dbConn.dbErrMessage)
	}

	return nil
}

// Reconnect opens a fresh libpq connection described by config and installs
// it as dbConn's connection. It can be used in DB connection retry cases.
//
// Only non-empty config values (and a positive port) are passed to
// PQconnectdbParams; libpq applies its own defaults for anything omitted.
// A non-empty config.TLSMode() is mapped onto libpq's sslmode as follows:
//   - "required" becomes "require"
//   - "verify_ca" becomes "verify-ca"
//   - anything else becomes "verify-full"
//
// The optional cert, key, and CA file paths are only sent when a TLS mode is
// set.
//
// On success, any connection previously held by dbConn is closed with
// PQfinish, so repeated reconnects do not leak libpq connections. On failure,
// the previous connection is left untouched, dbErrMessage holds libpq's error
// text, and a "DB down" error is returned. database.ErrDBConfigNil is
// returned when config is nil.
func (dbConn *psqlDBConnection) Reconnect(config *database.DBConfig) error {
	if config == nil {
		return database.ErrDBConfigNil
	}

	var cKeys []*C.char
	var cVals []*C.char

	// The C strings only have to live until PQconnectdbParams returns.
	defer func() {
		for _, ptr := range cKeys {
			if ptr != nil {
				C.free(unsafe.Pointer(ptr))
			}
		}
		for _, ptr := range cVals {
			if ptr != nil {
				C.free(unsafe.Pointer(ptr))
			}
		}
	}()

	// addParam appends one keyword/value pair as C strings.
	addParam := func(key, value string) {
		cKeys = append(cKeys, C.CString(key))
		cVals = append(cVals, C.CString(value))
	}

	if tlsMode := config.TLSMode(); tlsMode != "" {
		switch tlsMode {
		case "required":
			addParam("sslmode", "require")
		case "verify_ca":
			addParam("sslmode", "verify-ca")
		default:
			addParam("sslmode", "verify-full")
		}

		if path := config.TLSCertFile(); path != "" {
			addParam("sslcert", path)
		}
		if path := config.TLSKeyFile(); path != "" {
			addParam("sslkey", path)
		}
		if path := config.TLSCaFile(); path != "" {
			addParam("sslrootcert", path)
		}
	}

	// Basic connection info
	if host := config.Hostname(); host != "" {
		addParam("host", host)
	}
	if name := config.DBName(); name != "" {
		addParam("dbname", name)
	}
	if user := config.User(); user != "" {
		addParam("user", user)
	}
	if password := config.Password(); password != "" {
		addParam("password", password)
	}
	if dbPort := config.Port(); dbPort > 0 {
		addParam("port", strconv.Itoa(dbPort))
	}

	// PQconnectdbParams expects both arrays to be NULL-terminated.
	cKeys = append(cKeys, nil)
	cVals = append(cVals, nil)

	pqConn := C.PQconnectdbParams(&cKeys[0], &cVals[0], 0)

	if C.PQstatus(pqConn) != C.CONNECTION_OK {
		dbConn.dbErrMessage = C.GoString(C.PQerrorMessage(pqConn))
		C.PQfinish(pqConn)

		return fmt.Errorf("DB down, psql_err_msg: %s", dbConn.DBErrMessage())
	}

	// Release the previous connection only once the replacement is usable.
	if dbConn.conn != nil {
		C.PQfinish(dbConn.conn)
	}
	dbConn.conn = pqConn

	return nil
}

// Close releases the libpq connection, if any, and marks this connection as
// closed. Calling it on an already-closed connection is a no-op.
func (dbConn *psqlDBConnection) Close() {
	if dbConn.conn == nil {
		return
	}

	C.PQfinish(dbConn.conn)
	dbConn.conn = nil
}

// Begin opens a transaction by issuing BEGIN.
//
// Returns database.ErrDuplicatedDBTransaction if a transaction is already
// open, or the error from execute; the transaction flag is only set once
// BEGIN succeeds.
func (dbConn *psqlDBConnection) Begin() error {
	if dbConn.isInTransaction {
		return database.ErrDuplicatedDBTransaction
	}

	_, err := dbConn.execute("begin;")
	if err == nil {
		dbConn.isInTransaction = true
	}

	return err
}

// Commit issues COMMIT for the transaction opened by Begin.
//
// Returns database.ErrNoDBTransaction when no transaction is open, or the
// error from execute. If COMMIT fails, the connection stays marked as being
// in a transaction.
func (dbConn *psqlDBConnection) Commit() error {
	if dbConn.isInTransaction {
		_, err := dbConn.execute("commit;")
		if err == nil {
			dbConn.isInTransaction = false
		}
		return err
	}

	return database.ErrNoDBTransaction
}

// Rollback issues ROLLBACK for the transaction opened by Begin.
//
// Returns database.ErrNoDBTransaction when no transaction is open, or the
// error from execute. If ROLLBACK fails, the connection stays marked as being
// in a transaction.
func (dbConn *psqlDBConnection) Rollback() error {
	inTx := dbConn.isInTransaction
	if !inTx {
		return database.ErrNoDBTransaction
	}

	_, err := dbConn.execute("rollback;")
	if err == nil {
		dbConn.isInTransaction = false
	}

	return err
}

// IsAlive reports whether a libpq connection exists and libpq reports its
// status as CONNECTION_OK.
func (dbConn *psqlDBConnection) IsAlive() bool {
	return dbConn.conn != nil && C.PQstatus(dbConn.conn) == C.CONNECTION_OK
}

// execute runs the statement fmt.Sprintf(format, arg...) and returns the
// number of rows it affected, as reported by PQcmdTuples. Commands that
// report no count (e.g. BEGIN/COMMIT) yield 0.
//
// It fails with:
//   - database.ErrNoDBSession outside a session
//   - database.ErrDBConnNil without a connection
//   - database.ErrDBResultNil if libpq returns no result; the connection's
//     error message is stored in dbErrMessage
//   - an error carrying the SQLSTATE and message when the statement does not
//     finish with PGRES_COMMAND_OK or PGRES_TUPLES_OK
func (dbConn *psqlDBConnection) execute(format string, arg ...any) (int, error) {
	// prevent calling functions after EndSession()
	if !dbConn.hasSession {
		return 0, database.ErrNoDBSession
	}

	if dbConn.conn == nil {
		return 0, database.ErrDBConnNil
	}

	// NOTE(review): arguments are interpolated verbatim into the SQL text via
	// fmt.Sprintf. Callers must escape any untrusted value themselves;
	// PQexecParams would be the safe alternative.
	sqlC := C.CString(fmt.Sprintf(format, arg...))
	defer C.free(unsafe.Pointer(sqlC))

	result := C.PQexec(dbConn.conn, sqlC)
	if result == nil {
		// libpq reports the reason for a NULL result on the connection.
		dbConn.dbErrMessage = C.GoString(C.PQerrorMessage(dbConn.conn))
		return 0, database.ErrDBResultNil
	}
	defer C.PQclear(result)

	status := C.PQresultStatus(result)
	if status != C.PGRES_COMMAND_OK && status != C.PGRES_TUPLES_OK {
		dbConn.dbErrCode = C.GoString(C.PQresultErrorField(result, C.PG_DIAG_SQLSTATE))
		dbConn.dbErrMessage = C.GoString(C.PQresultErrorMessage(result))
		dbConn.dbResultStatus = int(status)

		return 0, fmt.Errorf("psql_err_msg: %s, psql_err_code: %s", dbConn.DBErrMessage(), dbConn.DBErrCode())
	}

	tuples := C.GoString(C.PQcmdTuples(result))
	if tuples == "" {
		// ignore for queries that dont update data
		return 0, nil
	}

	affectedRowsCount, err := strconv.Atoi(tuples)
	if err != nil {
		return 0, fmt.Errorf("could not get affected rows count: %w", err)
	}

	return affectedRowsCount, nil
}

// Execute runs a data-modifying statement inside the transaction opened by
// Begin and returns the affected row count. Outside a transaction it
// returns database.ErrNoDBTransaction without touching the database.
func (dbConn *psqlDBConnection) Execute(format string, arg ...any) (int, error) {
	if dbConn.isInTransaction {
		return dbConn.execute(format, arg...)
	}

	return 0, database.ErrNoDBTransaction
}

// Select runs the query fmt.Sprintf(format, arg...) and returns a cursor
// over its rows. The returned *PsqlDBresult owns the underlying PGresult;
// the caller must call Free on it to release that memory.
//
// It fails with:
//   - database.ErrNoDBSession outside a session
//   - database.ErrDBConnNil without a connection
//   - database.ErrDBResultNil if libpq returns no result; the connection's
//     error message is stored in dbErrMessage
//   - an error carrying the SQLSTATE and message when the status is not
//     PGRES_TUPLES_OK
func (dbConn *psqlDBConnection) Select(format string, arg ...any) (database.DBresult, error) {
	// prevent calling functions after EndSession()
	if !dbConn.hasSession {
		return nil, database.ErrNoDBSession
	}

	if dbConn.conn == nil {
		return nil, database.ErrDBConnNil
	}

	// NOTE(review): arguments are interpolated verbatim into the SQL text via
	// fmt.Sprintf. Callers must escape any untrusted value themselves.
	sqlC := C.CString(fmt.Sprintf(format, arg...))
	defer C.free(unsafe.Pointer(sqlC))

	result := C.PQexec(dbConn.conn, sqlC)
	if result == nil {
		// libpq reports the reason for a NULL result on the connection.
		dbConn.dbErrMessage = C.GoString(C.PQerrorMessage(dbConn.conn))
		return nil, database.ErrDBResultNil
	}

	status := C.PQresultStatus(result)
	if status != C.PGRES_TUPLES_OK {
		dbConn.dbErrCode = C.GoString(C.PQresultErrorField(result, C.PG_DIAG_SQLSTATE))
		dbConn.dbErrMessage = C.GoString(C.PQresultErrorMessage(result))
		dbConn.dbResultStatus = int(status)
		C.PQclear(result)

		return nil, fmt.Errorf("psql_err_msg: %s, psql_err_code: %s", dbConn.DBErrMessage(), dbConn.DBErrCode())
	}

	rowCount := int(C.PQntuples(result))
	colNames := make([]string, int(C.PQnfields(result)))
	for i := range colNames {
		colNames[i] = C.GoString(C.PQfname(result, C.int(i)))
	}

	// result is intentionally not cleared here: ownership passes to the cursor.
	return &PsqlDBresult{
		pgResult:     result,
		totalRowNo:   rowCount,
		currentRowNo: 0,
		columnNames:  colNames,
	}, nil
}

// DBErrCode returns the SQLSTATE code recorded by the most recent failed
// statement on this connection, or "" if none has been recorded.
func (dbConn *psqlDBConnection) DBErrCode() string {
	return dbConn.dbErrCode
}

// DBErrMessage returns the most recently recorded libpq error message,
// whether from a failed statement or a failed connection attempt, or "" if
// none has been recorded.
func (dbConn *psqlDBConnection) DBErrMessage() string {
	return dbConn.dbErrMessage
}

// DBResultStatus returns the PQresultStatus value recorded by the most
// recent failed statement on this connection (zero value if none).
func (dbConn *psqlDBConnection) DBResultStatus() int {
	return dbConn.dbResultStatus
}

// PsqlDBresult is a forward-only cursor over the PGresult returned by
// psqlDBConnection.Select. Access is serialized by mu, and Free must be
// called to release the underlying PGresult.
type PsqlDBresult struct {
	// pgResult is the libpq result; nil after Free.
	pgResult     *C.PGresult
	// totalRowNo is the row count reported by PQntuples.
	totalRowNo   int
	// columnNames holds the result's column names, indexed by column number.
	columnNames  []string
	// currentRowNo is the index of the next row Fetch will return.
	currentRowNo int
	mu           sync.Mutex
}

// PsqlDBresult utility functions

// Fetch returns the current row as a column-name-to-value map and advances
// the cursor. SQL NULL values are returned as "".
//
// Returns database.ErrDBResultNil after Free, and database.ErrNoTableRows
// once every row has been fetched.
func (dbResult *PsqlDBresult) Fetch() (map[string]string, error) {
	dbResult.mu.Lock()
	defer dbResult.mu.Unlock()

	switch {
	case dbResult.pgResult == nil:
		return nil, database.ErrDBResultNil
	case dbResult.currentRowNo >= dbResult.totalRowNo:
		return nil, database.ErrNoTableRows
	}

	res := dbResult.pgResult
	row := C.int(dbResult.currentRowNo)
	values := make(map[string]string, len(dbResult.columnNames))

	for idx, name := range dbResult.columnNames {
		col := C.int(idx)
		if C.PQgetisnull(res, row, col) == 1 {
			values[name] = ""
			continue
		}
		// Copy exactly PQgetlength bytes instead of relying on NUL termination.
		values[name] = C.GoStringN(C.PQgetvalue(res, row, col), C.PQgetlength(res, row, col))
	}

	dbResult.currentRowNo++

	return values, nil
}

// HasNextRow reports whether Fetch still has rows left to return.
func (dbResult *PsqlDBresult) HasNextRow() bool {
	dbResult.mu.Lock()
	defer dbResult.mu.Unlock()

	return dbResult.totalRowNo > dbResult.currentRowNo
}

// Free releases the underlying PGresult with PQclear. It is safe to call
// more than once; later calls do nothing.
func (dbResult *PsqlDBresult) Free() {
	dbResult.mu.Lock()
	defer dbResult.mu.Unlock()

	if dbResult.pgResult == nil {
		return
	}

	C.PQclear(dbResult.pgResult)
	dbResult.pgResult = nil
}
