/*
** Job Arranger for ZABBIX
** Copyright (C) 2025 Daiwa Institute of Research Ltd. All Rights Reserved.
**
** This program is free software; you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation; either version 2 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program; if not, write to the Free Software
** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
**/
package jatcp

import (
	"errors"
	"fmt"
	"net"

	"jobarranger2/src/libs/golibs/common"
)

// TcpServer is a TCP listener that can restrict which peer IP addresses are
// allowed to connect.
type TcpServer struct {
	// listener is the underlying TCP listener opened by CreateTcpServer.
	listener net.Listener
	// allowedIPs holds the canonical (net.IP.String) form of every permitted
	// peer address. An empty set means every peer is accepted.
	allowedIPs map[string]struct{}
}

// CreateTcpServer starts listening for TCP connections on hostname:listenPort
// and returns a TcpServer that only accepts peers whose IP address appears in
// allowedIPs. A nil or empty allowedIPs accepts every peer.
//
// hostname may be empty, an IPv4 address, a host name, or an IPv6 address
// written with or without surrounding brackets. Every entry of allowedIPs must
// be a literal IP address; the first entry that is not makes the function
// return an error before any socket is opened. Errors from net.Listen are
// returned unchanged.
func CreateTcpServer(hostname string, listenPort int, allowedIPs []string) (*TcpServer, error) {
	address := tcpListenAddress(hostname, listenPort)

	// Set semantics with zero-size values; keys are normalized through
	// net.IP.String so that equivalent spellings (e.g. "::ffff:10.0.0.1" and
	// "10.0.0.1") collapse to the same entry that Accept compares against.
	allowedIPsMap := make(map[string]struct{}, len(allowedIPs))
	for _, ipStr := range allowedIPs {
		ip := net.ParseIP(ipStr)
		if ip == nil {
			return nil, fmt.Errorf("invalid IP address: %s", ipStr)
		}
		allowedIPsMap[ip.String()] = struct{}{}
	}

	listener, err := net.Listen("tcp", address)
	if err != nil {
		return nil, err
	}

	return &TcpServer{
		listener:   listener,
		allowedIPs: allowedIPsMap,
	}, nil
}

// tcpListenAddress builds the "host:port" string passed to net.Listen.
// IPv6 literals are bracketed by net.JoinHostPort; a host that is already
// bracketed (e.g. "[::1]") is unwrapped first so it is not double-bracketed.
func tcpListenAddress(hostname string, listenPort int) string {
	host := hostname
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, fmt.Sprint(listenPort))
}

// Accept waits for the next incoming TCP connection and returns it wrapped in
// a common.NetConnection.
//
// If the server has a non-empty allow-list, the peer's IP address is checked
// against it. A peer that is not listed has its connection closed, and an
// error is returned instead. An empty allow-list admits every peer. Errors
// from the underlying listener are returned unchanged.
func (server *TcpServer) Accept() (*common.NetConnection, error) {
	if server.listener == nil {
		return nil, errors.New("listener is nil")
	}

	connection, err := server.listener.Accept()
	if err != nil {
		return nil, err
	}

	remoteIP, err := connRemoteIP(connection)
	if err != nil {
		// Best-effort close: the peer is being dropped either way, and the
		// caller is told why through the returned error.
		_ = connection.Close()
		return nil, err
	}

	// Keys in allowedIPs were normalized with net.IP.String, so compare
	// using the same canonical form.
	if len(server.allowedIPs) > 0 {
		if _, allowed := server.allowedIPs[remoteIP.String()]; !allowed {
			_ = connection.Close()
			return nil, fmt.Errorf("ip: %s is not allowed", remoteIP.String())
		}
	}

	return &common.NetConnection{
		Conn: connection,
	}, nil
}

// connRemoteIP returns the peer IP address of connection.
//
// For TCP connections the IP is read directly from *net.TCPAddr. Its zone
// (e.g. "%eth0" on link-local IPv6 peers) is kept in a separate field, which
// net.ParseIP could not handle in the textual form. Other address types fall
// back to parsing the host part of RemoteAddr().String().
func connRemoteIP(connection net.Conn) (net.IP, error) {
	addr := connection.RemoteAddr()
	if addr == nil {
		return nil, errors.New("remote address is unavailable")
	}

	if tcpAddr, ok := addr.(*net.TCPAddr); ok && tcpAddr.IP != nil {
		return tcpAddr.IP, nil
	}

	remoteHost, _, err := net.SplitHostPort(addr.String())
	if err != nil {
		return nil, err
	}

	remoteIP := net.ParseIP(remoteHost)
	if remoteIP == nil {
		return nil, fmt.Errorf("invalid IP address: %s", remoteHost)
	}
	return remoteIP, nil
}

// Close shuts down the server's listener so that no further connections are
// accepted. It reports an error if the server was never given a listener;
// otherwise it returns whatever the listener's own Close returns.
func (server *TcpServer) Close() error {
	if l := server.listener; l != nil {
		return l.Close()
	}
	return errors.New("listener is nil")
}
