#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright the James Browning
# SPDX-License-Identifier: CC BY-NC-SA 4.0
"""Serve SNTP time while aspiring to MS-SNTP."""
from __future__ import division, print_function
import logging
import math
import select
import socket
import struct
import sys
import time

VERSION = "ms-sntpd v0.1"
JAN_1970 = 2208988800  # seconds from the NTP epoch (1900) to the Unix epoch
BIG = 2**32  # scale of one 32-bit half of a 64-bit NTP timestamp
MID = 2**31 - 1  # largest signed 32-bit value
FORMAT_NTP = "!BBBbIIIqqqq"  # 48-byte NTP header, timestamps as signed 64-bit
# signd request header: length, version, op, packet ID (16-bit), pad, key ID
FORMAT_SIGND_REQ = "!3I2HI"
FORMAT_SIGND_RESP = "!4I"  # signd reply header: length, version, op, packet ID
REFID_ORPHAN = 0x7F000001  # 127.0.0.1
REFID_NAK_CRYP = 0x43525950  # kiss code "CRYP" in network byte order
LEN_NTP = 48  # bare NTP header
# LEN_NTP_NAK = 52
LEN_NTP_SMB_B = 68  # header + 4-byte key ID + 16-byte MS-SNTP checksum
LEN_NTP_SMB_E = 120  # header + 72-byte extended MS-SNTP authenticator
TIMEOUT_SMB = 0.4  # default seconds to wait for signd

# ntp_signd operation codes
SIGN_TO_CLIENT = 0
ASK_SERVER_TO_SIGN = 1
CHECK_SERVER_SIGNATURE = 2
SIGNING_SUCCESS = 3
SIGNING_FAILURE = 4
SIGND_PROTO_VER = 0  # only protocol version signd speaks


def timelfp(now=None):
    """Get signed 64-bit integer of NTP epoch time.

    The high 32 bits hold NTP seconds folded into the signed range
    [-2**31, 2**31), so the value packs into the ``q`` fields of
    FORMAT_NTP.  The low 32 bits hold the binary fraction of the second.

    now -- optional POSIX timestamp to convert; defaults to time.time().
    """
    if now is None:
        now = time.time()
    # divmod keeps the fraction in [0, 1) even for negative timestamps.
    seconds, fraction = divmod(now, 1)
    seconds = int(seconds) + JAN_1970
    # Wrap into signed 32 bits.  The former ``while seconds >= MID`` loop
    # also wrapped the valid value 2**31 - 1 to the out-of-range -2**31 - 1.
    half = BIG // 2
    seconds = (seconds + half) % BIG - half
    return seconds * BIG | int(fraction * BIG)


def reponse(magic, one, two=None):
    """Fabricate an SNTP response packet.

    magic -- first octet of the request; only its version and mode bits
             (the low six bits) are used.
    one   -- the request's transmit timestamp, echoed as origin time.
    two   -- receive timestamp, also used as the reference time; defaults
             to timelfp() at call time.

    Returns the packed 48-byte header described by FORMAT_NTP.
    """
    if two is None:
        two = timelfp()
    vn_mode = magic & 0x3F
    # Client requests (mode 3, version 4, 3 or 2) are answered in server
    # mode (mode 4) with the same version.  The request's leap bits are
    # dropped so they are never copied into the reply.
    if vn_mode in (0x23, 0b011011, 0b010011):
        new_magic = vn_mode + 1
    else:
        new_magic = 0x27  # leap 0, version 4, mode 7
    return struct.pack(
        FORMAT_NTP,
        new_magic,  # leap version mode
        10,  # stratum
        7,  # poll
        -10,  # precision
        0,  # root delay
        0,  # root dispersion
        REFID_ORPHAN,  # refid/kiss
        two,  # reference time
        one,  # origin time
        two,  # receive time
        timelfp(),  # send time
    )


class SNTP:
    """Run an SNTP server that aspires to MS-SNTP.

    Requests of LEN_NTP bytes are answered directly.  Requests of
    LEN_NTP_SMB_B or LEN_NTP_SMB_E bytes are passed to signd for signing
    when its socket is connected.  If signd has not answered within
    ``timeout`` seconds, the request is answered unsigned instead.

    Work lists (oldest first):
      queued     -- [raddress, imagic, phase2, skey] waiting to go to signd
      outstading -- [expire, pkt_id, raddress, imagic, phase2, skey] sent
                    to signd and awaiting its reply
      outgoing   -- [raddress, packet] ready to send to clients
    """

    pkt_id = 0  # last packet ID handed to signd
    ntpd = None  # UDP socket serving time
    signd = None  # stream socket to signd; None when unavailable

    def __init__(self, args):
        """Prep variables and connect to sockets.

        ``args`` supplies ``timeout`` (seconds to wait for signd),
        ``signd`` (folder holding signd's ``socket``) and ``port`` (UDP
        port to serve on).  Exits if the port cannot be bound.  If signd
        cannot be reached, signing is disabled.
        """
        # Per-instance lists; class-level lists would be shared by every
        # SNTP instance.
        self.outstading = []
        self.queued = []
        self.outgoing = []
        self.timeout = args.timeout
        try:
            self.signd = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self.signd.settimeout(args.timeout)
            self.signd.connect(args.signd + "/socket")
        except OSError:
            # Covers a missing socket and a refused connection, not only
            # PermissionError; the server then runs without signing.
            log.error(
                'can not connect to "socket" in folder %r', args.signd
            )
            if self.signd is not None:
                self.signd.close()
            self.signd = None
        try:
            self.ntpd = socket.socket(
                socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP
            )
            self.ntpd.bind(("::", args.port))
        except OSError:
            log.critical("can not bind to UDP6 port %d", args.port)
            sys.exit(1)

    def __del__(self):
        """Send all packets still waiting in the outbox when vanishing."""
        if self.ntpd:
            # Drain ``outgoing``, the list ntp_egress pops from.  getattr
            # guards instances whose __init__ never ran.
            while getattr(self, "outgoing", None):
                self.ntp_egress()

    def ntp_ingress(self):
        """Read one request and add its reply to the outbox or notary pile."""
        ipacket, raddress = self.ntpd.recvfrom(512)
        ilen = len(ipacket)
        if ilen not in (LEN_NTP, LEN_NTP_SMB_B, LEN_NTP_SMB_E):
            log.warning("Got unexpected input packet length %d", ilen)
            return
        # Parse only the fixed header.  Authenticated requests carry
        # trailing bytes that FORMAT_NTP does not describe.
        pieces = struct.unpack_from(FORMAT_NTP, ipacket)
        imagic = pieces[0] & 0b00111111
        phase2 = reponse(imagic, pieces[-1])
        if ilen == LEN_NTP:
            self.outgoing.append([raddress, phase2])
        elif self.signd:
            skey = ipacket[LEN_NTP:LEN_NTP + 4]  # key ID follows the header
            self.queued.append([raddress, imagic, phase2, skey])

    def ntp_egress(self):
        """Send the oldest response back to its original sender."""
        raddress, phase2 = self.outgoing.pop(0)
        try:
            self.ntpd.sendto(phase2, raddress)
        except OSError as err:
            # One unreachable client must not take the server down.
            log.warning("could not send to %r: %s", raddress, err)

    def bing(self):
        """Dispatch packets to their destination until interrupted."""
        ntpd = self.ntpd.fileno()
        try:
            while True:
                # signd may be dropped mid-run, so rebuild the sets each pass.
                signd = self.signd.fileno() if self.signd else None
                readers = [ntpd] if signd is None else [ntpd, signd]
                # Only watch for writability when there is data to write.
                # Idle sockets are always writable and would make select()
                # return at once, spinning the CPU.
                writers = [ntpd] if self.outgoing else []
                if signd is not None and self.queued:
                    writers.append(signd)
                # Block until traffic arrives or the earliest signd request
                # expires.
                timeout = None
                if self.outstading:
                    soonest = min(pkt[0] for pkt in self.outstading)
                    timeout = max(0.0, soonest - time.monotonic())
                rlist, wlist, xlist = select.select(
                    readers, writers, readers, timeout
                )
                if ntpd in xlist:
                    log.critical(
                        "Something seems to have happened to the NTP socket"
                    )
                    sys.exit(1)
                if ntpd in rlist:
                    self.ntp_ingress()
                if signd is not None:
                    if signd in xlist:
                        log.error(
                            "Something seems to have happened to the "
                            "signd socket"
                        )
                        self._drop_signd()
                    if self.signd and self.queued and signd in wlist:
                        self.signd_ingress()
                    # Read whenever readable, so EOF from signd is noticed.
                    if self.signd and signd in rlist:
                        self.signd_egress()
                if self.outstading:
                    self.signd_timeout()
                if self.outgoing and ntpd in wlist:
                    self.ntp_egress()
        except KeyboardInterrupt:
            print("")

    def signd_timeout(self):
        """Collect the dead letters and send them on unsigned."""
        now = time.monotonic()
        waiting = []
        for pkt in self.outstading:
            expire, _pkt_id, raddress, _imagic, phase2, skey = pkt
            if expire > now:
                waiting.append(pkt)
                continue
            # Refresh the transmit timestamp (signed, as in FORMAT_NTP) and
            # echo the key ID.  The origin timestamp is left intact.
            pkt_out = phase2[:40] + struct.pack("!q", timelfp()) + skey
            self.outgoing.append([raddress, pkt_out])
            log.warning("expired packet to signd")
        self.outstading[:] = waiting

    def signd_egress(self):
        """Move a signed packet from signd to the outbox."""
        try:
            header = self._signd_recv_exact(16)
            size, version, operation, pkt_id = struct.unpack(
                FORMAT_SIGND_RESP, header
            )
            # ``size`` counts the bytes after the length field, just as the
            # 64 sent in signd_ingress does.  Twelve of those bytes are the
            # rest of the header already read above.
            if not 12 <= size <= 512:
                raise ValueError("bad signd reply length %d" % size)
            pkt_out = self._signd_recv_exact(size - 12)
        except (OSError, ValueError) as err:
            log.error("lost sync with signd: %s", err)
            self._drop_signd()
            return
        if version != SIGND_PROTO_VER or operation != SIGNING_SUCCESS:
            log.warning(
                "got op%d v%d when expecting op%d v%d",
                operation,
                version,
                SIGNING_SUCCESS,
                SIGND_PROTO_VER,
            )
            return  # the request is answered unsigned once it expires
        if len(pkt_out) not in (LEN_NTP_SMB_B, LEN_NTP_SMB_E):
            log.warning("Got unexpected length %d from signd", size)
            return
        for index, outstanding in enumerate(self.outstading):
            if pkt_id == outstanding[1]:
                del self.outstading[index]
                raddress = outstanding[2]
                self.outgoing.append([raddress, pkt_out])
                log.info("send queueing ms-sntp return packet.")
                return
        log.warning("did not send queue ms-sntp return packet")

    def signd_ingress(self):
        """Send the oldest unsigned packet to the notary (signd)."""
        raddress, imagic, phase2, skey = self.queued.pop(0)
        # The request's packet ID field is 16 bits wide ("H").
        self.pkt_id = (self.pkt_id + 1) & 0xFFFF
        header = struct.pack(
            FORMAT_SIGND_REQ,
            64,  # bytes after this length field (16 + LEN_NTP)
            SIGND_PROTO_VER,  # protocol version (0)
            SIGN_TO_CLIENT,  # operation sign message (0)
            self.pkt_id,  # packet ID, echoed in signd's reply
            0,  # padding
            struct.unpack("!I", skey)[0],  # key ID bytes, passed unchanged
        )
        # Track the request before sending, so a failed send still gets an
        # (unsigned) answer through signd_timeout.
        expire = time.monotonic() + self.timeout
        self.outstading.append(
            [expire, self.pkt_id, raddress, imagic, phase2, skey]
        )
        try:
            self.signd.sendall(header + phase2)
        except OSError as err:
            log.error("could not send to signd: %s", err)
            self._drop_signd()

    def _signd_recv_exact(self, count):
        """Read exactly ``count`` bytes from signd or raise ConnectionError."""
        chunks = []
        while count > 0:
            chunk = self.signd.recv(count)
            if not chunk:
                raise ConnectionError("signd closed the connection")
            chunks.append(chunk)
            count -= len(chunk)
        return b"".join(chunks)

    def _drop_signd(self):
        """Close signd and schedule its pending requests for unsigned replies."""
        if self.signd is not None:
            self.signd.close()
        self.signd = None
        now = time.monotonic()
        # Requests that never reached signd expire immediately.  They have
        # no packet ID because none was ever assigned.
        for raddress, imagic, phase2, skey in self.queued:
            self.outstading.append([now, None, raddress, imagic, phase2, skey])
        self.queued.clear()


# Module-wide logger; the SNTP class reports through this name.
log = logging.getLogger(__name__)
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Serve NTP time ineffectively")
    # Options are registered in the order they are listed by --help.
    cli.add_argument(
        "-p", "--port",
        type=int,
        default=123,
        help="UDP port to serve time on (if not %(default)d)",
    )
    cli.add_argument(
        "-s", "--signd",
        default="/var/lib/samba/ntp_signd",
        help="folder to look for signd socket in (if not %(default)r)",
    )
    cli.add_argument(
        "-t", "--timeout",
        type=float,
        default=TIMEOUT_SMB,
        help="timeout on the signd socket (if not %(default)rs)",
    )
    cli.add_argument(
        "-v", "--version",
        action="version",
        version=VERSION,
        help="display the version and exit",
    )
    # Keep a module-level reference so the server outlives bing().
    server = SNTP(cli.parse_args())
    server.bing()

"""Changelog

2025-02-23 v0.1
- seemingly better timing than NTPsec
- some safety checks
- bush league logging

2025-02-22 unversioned
- things should have worked...
"""
