//go:build windows
// +build windows

/*
** Job Arranger for ZABBIX
** Copyright (C) 2025 Daiwa Institute of Research Ltd. All Rights Reserved.
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
**/

package agentutils

import (
	"encoding/gob"
	"fmt"
	"io"
	"net"
	"os"
	"syscall"
	"time"
	"unsafe"

	"golang.org/x/sys/windows"

	"jobarranger2/src/libs/golibs/common"
)

// ws2_32 is the Winsock 2 system library, loaded lazily by
// windows.NewLazyDLL the first time one of its procedures is used.
var ws2_32 = windows.NewLazyDLL("ws2_32.dll")

// Winsock procedures resolved from ws2_32 on first call.
var (
	getpeername         = ws2_32.NewProc("getpeername")
	wsaSocket           = ws2_32.NewProc("WSASocketW")
	procSend            = ws2_32.NewProc("send")
	procRecv            = ws2_32.NewProc("recv")
	procWSAGetLastError = ws2_32.NewProc("WSAGetLastError")
)

// RecreateSocket rebuilds a Winsock socket from a WSAProtocolInfo that was
// gob-encoded into socketFilePath. It wraps the socket in a
// common.NetConnection whose Conn is a *SocketConn.
//
// The address family, socket type and protocol passed to WSASocketW come from
// the decoded WSAProtocolInfo itself, so both IPv4 and IPv6 sockets can be
// recreated.
//
// NOTE(review): the protocol info is presumably produced by
// WSADuplicateSocket in the sending process — confirm against the writer.
func RecreateSocket(socketFilePath string) (*common.NetConnection, error) {
	file, err := os.Open(socketFilePath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	var protoInfo windows.WSAProtocolInfo
	if err := gob.NewDecoder(file).Decode(&protoInfo); err != nil {
		return nil, fmt.Errorf("decoding socket protocol info from %q: %w", socketFilePath, err)
	}

	var wsaData windows.WSAData
	if err := windows.WSAStartup(uint32(0x202), &wsaData); err != nil {
		return nil, fmt.Errorf("WSAStartup failed: %w", err)
	}
	// Balances the WSAStartup above. NOTE(review): the recreated socket is
	// used after this function returns, so this relies on another outstanding
	// WSAStartup (e.g. one performed by the Go runtime) keeping Winsock
	// initialized — confirm.
	defer windows.WSACleanup()

	// Recreate the socket from the serialized protocol information.
	sock, _, callErr := wsaSocket.Call(
		uintptr(protoInfo.AddressFamily),
		uintptr(protoInfo.SocketType),
		uintptr(protoInfo.Protocol),
		uintptr(unsafe.Pointer(&protoInfo)),
		0, // g: no socket group
		0, // dwFlags
	)
	// WSASocketW returns a pointer-sized SOCKET; INVALID_SOCKET is all ones.
	if sock == ^uintptr(0) {
		return nil, fmt.Errorf("WSASocketW failed: %w", callErr)
	}

	return &common.NetConnection{
		Conn: &SocketConn{S: windows.Handle(sock)},
	}, nil
}

// GetTarWriter returns netConn's underlying connection as an io.Writer. On
// Windows that connection is normally the *SocketConn built by
// RecreateSocket.
//
// When netConn.Conn is nil or does not implement io.Writer, it returns a true
// nil interface. Callers can then detect the failure with w == nil instead of
// receiving a typed-nil *SocketConn that panics on the first Write.
func GetTarWriter(netConn common.NetConnection) io.Writer {
	if w, ok := netConn.Conn.(io.Writer); ok {
		return w
	}
	return nil
}

// SocketConn is a TCP connection over a raw Winsock socket handle. It
// implements the net.Conn interface. Close releases the handle with
// closesocket.
type SocketConn struct {
	// S is the Winsock socket handle that all I/O is performed on.
	S windows.Handle
}

// SO_SNDTIMEO is the SOL_SOCKET option that sets the socket send timeout. Its
// value is in milliseconds when passed to SetsockoptInt.
const SO_SNDTIMEO = 0x1005

// Read receives up to len(b) bytes with a single non-overlapped WSARecv call.
// The OVERLAPPED and completion-routine arguments are both nil.
//
// It returns:
//   - io.EOF when WSARecv reports zero bytes (the peer closed the stream);
//   - os.ErrDeadlineExceeded when the receive timeout set by SetReadDeadline
//     expires (WSAETIMEDOUT);
//   - an *os.SyscallError for any other failure.
//
// An empty b returns (0, nil) without touching the socket.
func (c *SocketConn) Read(b []byte) (int, error) {
	// io.Reader permits an empty buffer; taking &b[0] would panic on it.
	if len(b) == 0 {
		return 0, nil
	}

	buf := windows.WSABuf{Len: uint32(len(b)), Buf: &b[0]}
	var flags, n uint32
	err := windows.WSARecv(c.S, &buf, 1, &n, &flags, nil, nil)
	if err != nil {
		if err == windows.WSAETIMEDOUT {
			// SO_RCVTIMEO (configured by SetReadDeadline) expired.
			return int(n), os.ErrDeadlineExceeded
		}
		return int(n), os.NewSyscallError("WSARecv", err)
	}
	if n == 0 {
		// A successful zero-byte receive signals end of stream.
		return 0, io.EOF
	}
	return int(n), nil
}

func (c *SocketConn) Write(buf []byte) (int, error) {
	total := 0
	for total < len(buf) {
		n, _, err := procSend.Call(
			uintptr(c.S),
			uintptr(unsafe.Pointer(&buf[total])),
			uintptr(len(buf)-total),
			0,
		)
		if err != syscall.Errno(0) {
			return total, err
		}
		if n == 0 {
			return total, fmt.Errorf("send returned 0")
		}
		total += int(n)
	}
	return total, nil

}

// Close releases the socket handle with closesocket. A failure is reported as
// an *os.SyscallError; success returns nil.
func (c *SocketConn) Close() error {
	if closeErr := windows.Closesocket(c.S); closeErr != nil {
		return os.NewSyscallError("closesocket", closeErr)
	}
	return nil
}

// LocalAddr returns the socket's local address as a *net.TCPAddr, as
// reported by getsockname. It returns nil if the address cannot be obtained
// or is not an IPv4/IPv6 address.
func (c *SocketConn) LocalAddr() net.Addr {
	sa, err := windows.Getsockname(c.S)
	if err != nil {
		return nil
	}
	return sockaddrToNetAddr(sa)
}
// RemoteAddr returns the peer's address as a *net.TCPAddr, as reported by
// getpeername. It returns nil if the call fails or the address family is not
// IPv4/IPv6.
func (c *SocketConn) RemoteAddr() net.Addr {
	var rsa syscall.RawSockaddrAny
	addrLen := int32(unsafe.Sizeof(rsa))

	r1, _, _ := getpeername.Call(
		uintptr(c.S),
		uintptr(unsafe.Pointer(&rsa)),
		uintptr(unsafe.Pointer(&addrLen)),
	)
	// getpeername returns a 32-bit int: 0 on success, SOCKET_ERROR (-1) on
	// failure. Only the low 32 bits of r1 carry that value. On 64-bit, the
	// full register is not guaranteed to equal ^uintptr(0) on failure, so
	// check the truncated value instead.
	if int32(r1) != 0 {
		return nil
	}

	sa, err := anyToSockaddr(&rsa)
	if err != nil {
		return nil
	}

	return sockaddrToNetAddr(sa)
}

// anyToSockaddr converts a raw socket address filled in by Winsock into a
// windows.Sockaddr. Only AF_INET and AF_INET6 are supported; any other family
// yields an error. Ports are converted from network byte order by swapping
// the two bytes.
func anyToSockaddr(rsa *syscall.RawSockaddrAny) (windows.Sockaddr, error) {
	// ntohs: swap the two bytes of a network-order port.
	hostPort := func(p uint16) int { return int(p>>8 | p<<8) }

	family := rsa.Addr.Family
	if family == syscall.AF_INET {
		in4 := (*syscall.RawSockaddrInet4)(unsafe.Pointer(rsa))
		return &windows.SockaddrInet4{
			Addr: in4.Addr,
			Port: hostPort(in4.Port),
		}, nil
	}
	if family == syscall.AF_INET6 {
		in6 := (*syscall.RawSockaddrInet6)(unsafe.Pointer(rsa))
		return &windows.SockaddrInet6{
			Addr:   in6.Addr,
			Port:   hostPort(in6.Port),
			ZoneId: in6.Scope_id,
		}, nil
	}
	return nil, fmt.Errorf("unsupported family")
}

func sockaddrToNetAddr(sa windows.Sockaddr) net.Addr {
	switch v := sa.(type) {
	case *windows.SockaddrInet4:
		return &net.TCPAddr{IP: v.Addr[:], Port: v.Port}
	case *windows.SockaddrInet6:
		var zone string
		if v.ZoneId != 0 { // only set Zone if nonzero
			zone = zoneToString(v.ZoneId)
		}
		return &net.TCPAddr{IP: v.Addr[:], Port: v.Port, Zone: zone}
	}
	return nil
}

// zoneToString maps an IPv6 zone (interface) index to the interface name.
// If no interface with that index can be found, it falls back to the decimal
// index.
func zoneToString(zoneId uint32) string {
	if iface, lookupErr := net.InterfaceByIndex(int(zoneId)); lookupErr == nil {
		return iface.Name
	}
	return fmt.Sprintf("%d", zoneId)
}

// SetDeadline sets both the read and write deadlines. net.Conn requires this
// to be equivalent to calling SetReadDeadline and SetWriteDeadline. If
// setting the read deadline fails, that error is returned and the write
// deadline is left unchanged.
func (c *SocketConn) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}

// SetReadDeadline configures the socket receive timeout (SO_RCVTIMEO) from t.
//
// The socket option holds a relative timeout, so the time remaining until t
// is converted once, when this method is called; it is not re-evaluated for
// each Read. A zero t disables the timeout.
func (c *SocketConn) SetReadDeadline(t time.Time) error {
	return windows.SetsockoptInt(c.S, windows.SOL_SOCKET, windows.SO_RCVTIMEO, socketTimeoutMillis(t))
}

// SetWriteDeadline configures the socket send timeout (SO_SNDTIMEO) from t.
//
// As with SetReadDeadline, the time remaining until t is converted once, when
// this method is called. A zero t disables the timeout.
func (c *SocketConn) SetWriteDeadline(t time.Time) error {
	return windows.SetsockoptInt(c.S, windows.SOL_SOCKET, SO_SNDTIMEO, socketTimeoutMillis(t))
}

// socketTimeoutMillis converts an absolute deadline into the millisecond
// value used for SO_RCVTIMEO / SO_SNDTIMEO.
//
// The mapping is:
//   - a zero t becomes 0, which disables the timeout;
//   - a deadline that has already passed becomes 1 ms, the smallest non-zero
//     timeout;
//   - otherwise the remaining time is rounded up to whole milliseconds, so a
//     sub-millisecond remainder never becomes 0, which would mean "no
//     timeout";
//   - the result is clamped to the int32 range, because the option value is
//     a 32-bit integer.
func socketTimeoutMillis(t time.Time) int {
	const maxTimeoutMillis = 1<<31 - 1

	if t.IsZero() {
		return 0
	}
	d := time.Until(t)
	if d <= 0 {
		return 1
	}
	ms := int64(d / time.Millisecond)
	if d%time.Millisecond != 0 {
		ms++
	}
	if ms > maxTimeoutMillis {
		ms = maxTimeoutMillis
	}
	return int(ms)
}
