#!/usr/bin/env python3
"""
NBD server v3 — newstyle protocol + PERC H710 chunk translation.

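Transformations applied on every read:
  1. Chunk translation  — every 5th 64KB chunk of the virtual device is not
                          read from DEV (PERC internal metadata) and returns
                          zeros; the other four chunks of each group of five
                          map contiguously onto DEV
  2. Superblock patch   — clears the metadata_csum / gdt_csum feature bits and
                          journal-related feature bits, and zeroes s_checksum
  3. GDT reconstruction — synthesizes group descriptors for the parts of the
                          GDT that fall inside those zero-filled chunks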

Usage:
    python3 nbd_server_v3.py &
    nbd-client 127.0.0.1 10809 /dev/nbd0 -N ""
    mount -o ro,norecovery -t ext4 /dev/nbd0 /mnt/root
"""

import socket
import struct
import threading

# ── Physical layout ───────────────────────────────────────────────────────────
# Block device that raw_read() opens for every read.
DEV           = '/dev/md0'
CHUNK_BYTES   = 128 * 512          # 64 KB (128 sectors); unit of chunk translation
LV_PHYS_START = 5120000 * 512      # byte 2,621,440,000 on DEV; virtual chunk 0 maps here
VIRT_SIZE     = 9365766144 * 512   # from superblock block count; export size sent to clients

# ── ext4 filesystem parameters ────────────────────────────────────────────────
BSIZE      = 4096      # filesystem block size in bytes
BPG        = 32768     # blocks per group (not referenced elsewhere in this file)
GDT_ENTRY  = 64        # bytes per group descriptor
NUM_GROUPS = 35728     # number of block groups == number of GDT entries

# Primary group-descriptor table: starts at block 1 (byte BSIZE) and holds
# NUM_GROUPS descriptors; [GDT_START_VIRT, GDT_END_VIRT) in virtual bytes.
GDT_START_VIRT = BSIZE
GDT_END_VIRT   = BSIZE + NUM_GROUPS * GDT_ENTRY

# Primary superblock: SB_SIZE bytes at virtual byte SB_VIRT_OFFSET.
SB_VIRT_OFFSET = 1024
SB_SIZE        = 1024

# Byte offsets of fields inside the superblock.
SB_INCOMPAT_OFF  = 96      # s_feature_incompat (0x60)
SB_RO_COMPAT_OFF = 100     # s_feature_ro_compat (0x64)
SB_CHECKSUM_OFF  = 1020    # s_checksum (0x3FC)

# NOTE(review): get_patched_sb() applies INCOMPAT_HAS_JOURNAL to the
# s_feature_incompat word, where ext4 bit 0x0004 is INCOMPAT_RECOVER
# ("needs_recovery").  ext4's has_journal is bit 0x0004 of s_feature_compat
# (superblock offset 0x5C), which this constant does not touch.
INCOMPAT_HAS_JOURNAL   = 0x00000004
RO_COMPAT_METADATA_CSUM = 0x00000400
RO_COMPAT_GDT_CSUM     = 0x00000010

# Patched superblock, built lazily by get_patched_sb(); _sb_lock serializes
# that first build across client threads.
_patched_sb   = None
_sb_lock      = threading.Lock()

# ── NBD newstyle protocol constants ──────────────────────────────────────────
# Handshake magics.
NBDMAGIC         = 0x4e42444d41474943  # "NBDMAGIC"
IHAVEOPT         = 0x49484156454F5054  # "IHAVEOPT"
REPLYMAGIC       = 0x3e889045565a9     # starts every option-haggling reply

# Option types sent by the client during option haggling.
NBD_OPT_EXPORT_NAME = 1
NBD_OPT_ABORT       = 2
NBD_OPT_LIST        = 3
NBD_OPT_GO          = 7

# Option reply types; error replies have bit 31 set.
NBD_REP_ACK         = 1
NBD_REP_SERVER      = 2
NBD_REP_ERR_UNSUP   = (1 << 31) | 1
NBD_REP_ERR_POLICY  = (1 << 31) | 2

# Transmission flags (sent together with the export size).
NBD_FLAG_HAS_FLAGS   = 1 << 0
NBD_FLAG_READ_ONLY   = 1 << 1
NBD_FLAG_SEND_FLUSH  = 1 << 2
# Handshake-phase flags.
NBD_FLAG_FIXED_NEWSTYLE = 1 << 0   # handshake flag; bit 0 in both server and client flags
NBD_FLAG_C_NO_ZEROES    = 1 << 1   # client flag: omit the 124 zero bytes of padding

# Transmission-phase request / simple-reply magics.
NBD_REQUEST_MAGIC = 0x25609513
NBD_REPLY_MAGIC   = 0x67446698

# Transmission-phase request types.
NBD_CMD_READ  = 0
NBD_CMD_WRITE = 1
NBD_CMD_DISC  = 2
NBD_CMD_FLUSH = 3


# ── Chunk translation ─────────────────────────────────────────────────────────

def raw_read(virt_offset, length):
    """Read *length* bytes of the virtual device starting at *virt_offset*.

    The virtual address space is cut into stripes of five CHUNK_BYTES chunks.
    Chunks 0-3 of stripe s are read from DEV at
        LV_PHYS_START + (4 * s + chunk) * CHUNK_BYTES + offset_in_chunk,
    while chunk 4 is never read and stays zero-filled.  If DEV returns fewer
    bytes than asked (e.g. at its end), the missing tail also stays zero.

    Returns immutable bytes of exactly *length* bytes.
    """
    out = bytearray(length)
    stripe_bytes = 5 * CHUNK_BYTES
    end = virt_offset + length
    cur = virt_offset
    with open(DEV, 'rb') as dev:
        while cur < end:
            stripe, offset_in_stripe = divmod(cur, stripe_bytes)
            chunk, offset_in_chunk = divmod(offset_in_stripe, CHUNK_BYTES)
            # Never cross a chunk boundary in one read.
            span = min(CHUNK_BYTES - offset_in_chunk, end - cur)
            if chunk != 4:
                dev.seek(LV_PHYS_START
                         + (stripe * 4 + chunk) * CHUNK_BYTES
                         + offset_in_chunk)
                piece = dev.read(span)
                rel = cur - virt_offset
                out[rel:rel + len(piece)] = piece
            cur += span
    return bytes(out)


# ── GDT synthesis ─────────────────────────────────────────────────────────────

def make_gdt_entry(n):
    """Return the synthesized GDT_ENTRY-byte group descriptor for group *n*.

    Only three little-endian u32 location fields are written:
      offset 0  block bitmap location  = 1038 + n
      offset 4  inode bitmap location  = 1054 + n
      offset 8  inode table location   = 1070 + 512 * n
    Every other byte (free counts, flags, checksum, high halves) is zero.  A
    zero checksum is tolerable only because get_patched_sb() clears the
    metadata_csum / gdt_csum features.

    NOTE(review): with this linear pattern the block bitmap of group n + 16
    equals the inode bitmap of group n (1038 + n + 16 == 1054 + n), so it can
    only match a real layout for the first 16 groups.  patch_gdt() synthesizes
    only descriptors inside zero-filled chunks; the first such chunk starts at
    virtual byte 4 * CHUNK_BYTES = 262144, i.e. descriptor (262144 - 4096) / 64
    = 4032.  Confirm the pattern against intact descriptors in that range.
    """
    desc = bytearray(GDT_ENTRY)
    for field_off, value in ((0, 1038 + n), (4, 1054 + n), (8, 1070 + 512 * n)):
        struct.pack_into('<I', desc, field_off, value)
    return bytes(desc)


def patch_gdt(data, virt_offset, length):
    """Overwrite zero-filled GDT bytes with synthesized group descriptors.

    raw_read() leaves chunk index 4 of every five-chunk stripe zero-filled.
    Wherever such a chunk overlaps the GDT [GDT_START_VIRT, GDT_END_VIRT),
    the corresponding bytes of *data* are replaced with descriptors from
    make_gdt_entry().  Bytes that came from DEV are never touched.

    Args:
        data: bytearray holding the bytes for [virt_offset, virt_offset +
              length); modified in place.
        virt_offset: virtual byte offset that data[0] corresponds to.
        length: number of bytes of *data* to consider.
    """
    stripe_bytes = 5 * CHUNK_BYTES
    pos = virt_offset
    end = virt_offset + length
    while pos < end:
        chunk_idx = (pos % stripe_bytes) // CHUNK_BYTES
        # Segment ends at the next chunk boundary or the end of the request.
        seg_end = min(pos + (CHUNK_BYTES - pos % CHUNK_BYTES), end)

        if chunk_idx == 4:
            ol_start = max(pos, GDT_START_VIRT)
            ol_end = min(seg_end, GDT_END_VIRT)
            if ol_start < ol_end:
                # Build each overlapped descriptor once and copy one slice,
                # rather than rebuilding a whole descriptor for every byte.
                # ol_end <= GDT_END_VIRT keeps last_grp < NUM_GROUPS.
                first_grp = (ol_start - GDT_START_VIRT) // GDT_ENTRY
                last_grp = (ol_end - 1 - GDT_START_VIRT) // GDT_ENTRY
                synth = b''.join(make_gdt_entry(g)
                                 for g in range(first_grp, last_grp + 1))
                synth_base = GDT_START_VIRT + first_grp * GDT_ENTRY
                data[ol_start - virt_offset:ol_end - virt_offset] = \
                    synth[ol_start - synth_base:ol_end - synth_base]

        pos = seg_end


# ── Superblock patch ──────────────────────────────────────────────────────────

def get_patched_sb():
    """Return the primary superblock with checksum and journal features cleared.

    Built once from raw_read(SB_VIRT_OFFSET, SB_SIZE), cached in _patched_sb
    and returned to every later caller; the check-and-build runs under
    _sb_lock so concurrent client threads build it only once.  If raw_read()
    raises, nothing is cached and the exception propagates, so a later call
    retries.

    Changes relative to the on-disk copy:
      * s_feature_compat    : COMPAT_HAS_JOURNAL (0x0004) cleared
      * s_feature_incompat  : bit 0x0004 cleared (ext4 INCOMPAT_RECOVER,
                              "needs_recovery"; held in INCOMPAT_HAS_JOURNAL)
      * s_feature_ro_compat : metadata_csum and gdt_csum cleared
      * s_checksum          : zeroed
    """
    global _patched_sb
    # ext4 s_feature_compat sits at superblock offset 0x5C and has_journal is
    # its bit 0x0004.  The incompat bit with the same value (cleared below)
    # is needs_recovery, so has_journal must be cleared separately.
    compat_off = 92
    compat_has_journal = 0x00000004

    with _sb_lock:
        if _patched_sb is not None:
            return _patched_sb
        sb = bytearray(raw_read(SB_VIRT_OFFSET, SB_SIZE))
        compat    = struct.unpack_from('<I', sb, compat_off)[0]
        incompat  = struct.unpack_from('<I', sb, SB_INCOMPAT_OFF)[0]
        ro_compat = struct.unpack_from('<I', sb, SB_RO_COMPAT_OFF)[0]
        compat    &= ~compat_has_journal
        incompat  &= ~INCOMPAT_HAS_JOURNAL
        ro_compat &= ~(RO_COMPAT_METADATA_CSUM | RO_COMPAT_GDT_CSUM)
        struct.pack_into('<I', sb, compat_off,       compat)
        struct.pack_into('<I', sb, SB_INCOMPAT_OFF,  incompat)
        struct.pack_into('<I', sb, SB_RO_COMPAT_OFF, ro_compat)
        # Checksum is meaningless once metadata_csum is off; zero it.
        struct.pack_into('<I', sb, SB_CHECKSUM_OFF,  0)
        _patched_sb = bytes(sb)
        print(f'[sb] patched: compat=0x{compat:08x} '
              f'incompat=0x{incompat:08x} ro_compat=0x{ro_compat:08x}')
        return _patched_sb


# ── Combined read ─────────────────────────────────────────────────────────────

def read_virtual(virt_offset, length):
    """Return *length* bytes at *virt_offset* of the patched virtual device.

    Starts from the chunk-translated bytes of raw_read(), then overlays the
    patched superblock and (via patch_gdt) synthesized group descriptors
    wherever the request overlaps them.  The result is immutable bytes of
    exactly *length* bytes.
    """
    buf = bytearray(raw_read(virt_offset, length))
    end = virt_offset + length

    sb_lo = SB_VIRT_OFFSET
    sb_hi = SB_VIRT_OFFSET + SB_SIZE
    if virt_offset < sb_hi and end > sb_lo:
        sb_bytes = get_patched_sb()
        lo = max(virt_offset, sb_lo)
        hi = min(end, sb_hi)
        # Copy the overlapping window [lo, hi) from the patched superblock.
        buf[lo - virt_offset:hi - virt_offset] = sb_bytes[lo - sb_lo:hi - sb_lo]

    touches_gdt = virt_offset < GDT_END_VIRT and end > GDT_START_VIRT
    if touches_gdt:
        patch_gdt(buf, virt_offset, length)

    return bytes(buf)


# ── NBD newstyle protocol ────────────────────────────────────────────────────

def recv_all(conn, n):
    """Receive exactly *n* bytes from socket *conn* and return them as bytes.

    Data is received directly into one preallocated buffer, so a large read
    costs O(n) copying, unlike the quadratic re-copying of ``bytes +=``.
    Returns b'' when *n* is zero or negative.

    Raises:
        ConnectionError: the peer closed the connection before *n* bytes
            arrived.
    """
    if n <= 0:
        return b''
    buf = bytearray(n)
    view = memoryview(buf)
    got = 0
    while got < n:
        count = conn.recv_into(view[got:])
        if not count:
            raise ConnectionError('client disconnected')
        got += count
    return bytes(buf)


def send_reply(conn, opt, reply_type, data=b''):
    """Send one option-haggling reply.

    Wire format: u64 REPLYMAGIC, u32 echoed option, u32 reply type,
    u32 payload length, then the payload itself (omitted when empty).
    """
    header = struct.pack('>QII', REPLYMAGIC, opt, reply_type)
    size_field = struct.pack('>I', len(data))
    for part in (header, size_field, data):
        if part:
            conn.sendall(part)


def send_export_info(conn, no_zeroes=False):
    """Answer NBD_OPT_EXPORT_NAME: export size, transmission flags, padding.

    Sends u64 VIRT_SIZE and u16 transmission flags (read-only, flush
    supported), followed by 124 zero bytes unless *no_zeroes* is true.
    """
    tx_flags = NBD_FLAG_SEND_FLUSH | NBD_FLAG_READ_ONLY | NBD_FLAG_HAS_FLAGS
    for field in (struct.pack('>Q', VIRT_SIZE), struct.pack('>H', tx_flags)):
        conn.sendall(field)
    if no_zeroes:
        return
    conn.sendall(bytes(124))


# Protocol values used by the client handler that have no constant above.
_NBD_REP_INFO        = 3                 # option reply carrying one NBD_INFO_* item
_NBD_REP_ERR_INVALID = (1 << 31) | 3     # option data was malformed
_NBD_INFO_EXPORT     = 0                 # info item: u64 size + u16 flags
_NBD_EPERM           = 1
_NBD_EINVAL          = 22
_TX_FLAGS = NBD_FLAG_HAS_FLAGS | NBD_FLAG_READ_ONLY | NBD_FLAG_SEND_FLUSH
_DISCARD_STEP = 1 << 20                  # drain write payloads 1 MiB at a time


def _discard(conn, n):
    """Read and drop exactly *n* bytes from *conn* in bounded steps."""
    while n > 0:
        step = min(n, _DISCARD_STEP)
        recv_all(conn, step)
        n -= step


def _negotiate(conn):
    """Run the fixed-newstyle handshake and option haggling.

    Returns True once the client has selected the export (NBD_OPT_EXPORT_NAME
    or NBD_OPT_GO) and the transmission phase should start.  Returns False if
    the client sent NBD_OPT_ABORT or an option header with a bad magic.
    ConnectionError from recv_all() propagates if the client disconnects.
    """
    # S: NBDMAGIC + IHAVEOPT + handshake flags (fixed newstyle).
    conn.sendall(struct.pack('>QQH', NBDMAGIC, IHAVEOPT,
                             NBD_FLAG_FIXED_NEWSTYLE))

    # C: client flags.
    client_flags = struct.unpack('>I', recv_all(conn, 4))[0]
    no_zeroes = bool(client_flags & NBD_FLAG_C_NO_ZEROES)

    while True:
        cli_magic, opt, opt_len = struct.unpack('>QII', recv_all(conn, 16))
        if cli_magic != IHAVEOPT:
            print(f'[nbd] bad option magic 0x{cli_magic:016x}')
            return False
        opt_data = recv_all(conn, opt_len) if opt_len else b''

        if opt == NBD_OPT_EXPORT_NAME:
            # No option reply: export info, then straight to transmission.
            send_export_info(conn, no_zeroes)
            return True

        if opt == NBD_OPT_GO:
            # Data: u32 name length, name, u16 info-request count, requests.
            if (len(opt_data) < 6
                    or 6 + struct.unpack_from('>I', opt_data)[0] > len(opt_data)):
                send_reply(conn, opt, _NBD_REP_ERR_INVALID)
                continue
            # The export description travels in NBD_REP_INFO; the closing
            # NBD_REP_ACK carries no data and commits to transmission.
            info = struct.pack('>HQH', _NBD_INFO_EXPORT, VIRT_SIZE, _TX_FLAGS)
            send_reply(conn, opt, _NBD_REP_INFO, info)
            send_reply(conn, opt, NBD_REP_ACK)
            return True

        if opt == NBD_OPT_LIST:
            # Advertise one anonymous export.
            name = b''
            send_reply(conn, opt, NBD_REP_SERVER,
                       struct.pack('>I', len(name)) + name)
            send_reply(conn, opt, NBD_REP_ACK)
        elif opt == NBD_OPT_ABORT:
            send_reply(conn, opt, NBD_REP_ACK)
            return False
        else:
            send_reply(conn, opt, NBD_REP_ERR_UNSUP)


def _serve(conn, addr):
    """Answer transmission-phase requests until NBD_CMD_DISC or a bad header."""
    while True:
        magic, _cmd_flags, cmd, handle, offset, length = \
            struct.unpack('>IHHQQI', recv_all(conn, 28))

        if magic != NBD_REQUEST_MAGIC:
            print(f'[nbd] bad request magic 0x{magic:08x}')
            return

        if cmd == NBD_CMD_READ:
            if offset + length > VIRT_SIZE:
                # Beyond the advertised size: error reply, no payload.
                conn.sendall(struct.pack('>IIQ', NBD_REPLY_MAGIC,
                                         _NBD_EINVAL, handle))
                continue
            try:
                payload = read_virtual(offset, length)
            except Exception as e:
                # Best-effort: the failure is only logged; zeros are served
                # with a success status.
                print(f'[nbd] read err offset={offset} len={length}: {e}')
                payload = b'\x00' * length
            # One send for header + payload instead of a separate tiny write.
            conn.sendall(struct.pack('>IIQ', NBD_REPLY_MAGIC, 0, handle)
                         + payload)

        elif cmd == NBD_CMD_DISC:
            print(f'[nbd] {addr} disconnected')
            return

        elif cmd == NBD_CMD_FLUSH:
            conn.sendall(struct.pack('>IIQ', NBD_REPLY_MAGIC, 0, handle))

        else:
            # Read-only export.  A write request is followed by `length`
            # payload bytes; consume them so the next header is parsed from
            # the right position, then refuse with EPERM.
            if cmd == NBD_CMD_WRITE:
                _discard(conn, length)
            conn.sendall(struct.pack('>IIQ', NBD_REPLY_MAGIC,
                                     _NBD_EPERM, handle))


def handle_client(conn, addr):
    """Serve one NBD client connection; *conn* is always closed on return."""
    print(f'[nbd] connect from {addr}')
    try:
        if not _negotiate(conn):
            return
        print(f'[nbd] {addr} — entering transmission phase')
        _serve(conn, addr)
    except (ConnectionError, BrokenPipeError, ConnectionResetError):
        print(f'[nbd] {addr} dropped')
    except Exception as e:
        print(f'[nbd] {addr} error: {e}')
    finally:
        conn.close()


def main():
    """Print the server configuration, then serve NBD clients forever.

    Listens on 127.0.0.1:10809 and hands each accepted connection to
    handle_client() on its own daemon thread.
    """
    print('PERC H710 recovery NBD server v3 (newstyle protocol)')
    print(f'  device     : {DEV}')
    print(f'  lv start   : byte {LV_PHYS_START}')
    # True division: with '//' the ':.1f' format could only ever print '.0'.
    print(f'  virtual sz : {VIRT_SIZE / (1024**3):.1f} GB')
    print('  features   : chunk-skip + sb-patch + gdt-synth + newstyle')
    print()

    # The context manager closes the listening socket if setup or accept()
    # raises (including KeyboardInterrupt).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind(('127.0.0.1', 10809))
        srv.listen(5)
        print('Listening on 127.0.0.1:10809')
        print()
        print('Connect with:')
        print('  nbd-client 127.0.0.1 10809 /dev/nbd0 -N ""')
        print('  mount -o ro,norecovery -t ext4 /dev/nbd0 /mnt/root')
        print()

        while True:
            conn, addr = srv.accept()
            threading.Thread(target=handle_client, args=(conn, addr),
                             daemon=True).start()


if __name__ == '__main__':
    main()
