"""
evaluate.py

Evaluation script for comparing predicted binding pockets against ground-truth labels.
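Computes MCC, precision, recall, F1, DCA, and DCC for every pair of predicted and
ground-truth binding pockets, matches the pairs greedily by MCC, and reports the
per-protein average over the ground-truth pockets.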

Author: [Xiaoyu An]
Date: 2025-06-01

Dependencies:
    - numpy
    - scikit-learn
    - biopython
    - csv
    - argparse
"""

import os
import re
import csv
import warnings
import argparse
import numpy as np
from Bio.PDB import PDBParser
from sklearn.metrics import (
    matthews_corrcoef, precision_score, recall_score, f1_score
)


def extract_label(label):
    """Parse a ground-truth label string into per-pocket residue identifiers.

    The string holds pockets separated by ``;``; each pocket is a
    whitespace-separated list of entries such as ``A_12`` or ``A_LYS12``.
    An entry without ``_`` reuses the chain of the previous entry in the same
    pocket (the chain is ``None`` until one is seen). Every signed integer in
    the residue part yields one ``"CHAIN_NUM"`` identifier.

    Entries with no residue number are skipped with a warning, and pockets
    that end up empty (e.g. from a trailing or doubled ``;``) are dropped:
    an empty pocket has no atoms, so its DCA/DCC would be undefined.

    Args:
        label: Raw label string from the label CSV.

    Returns:
        List of pockets, each a list of ``"CHAIN_NUM"`` strings.
    """
    label_bs = []
    for pocket_label in label.strip().split(";"):  # each part is a binding pocket
        pocket_bs = []
        chain = None
        for entry in pocket_label.split():
            if "_" in entry:
                # Keep the first and last fields, matching how predictions
                # are normalised in main(); avoids ValueError on "A_x_12".
                fields = entry.split("_")
                chain, residue = fields[0], fields[-1]
            else:
                residue = entry  # chain carried over from the previous entry
            matches = re.findall(r'([A-Za-z]+)?(-?\d+)', residue)
            if not matches:
                # A number-less id would crash to_one_hot/get_atoms later.
                warnings.warn(f"No residue number in label entry {entry!r}; skipped.")
                continue
            pocket_bs.extend(f"{chain}_{num}" for _, num in matches)
        if pocket_bs:
            label_bs.append(pocket_bs)
    return label_bs


def to_one_hot(residue_list, chain_lengths):
    """Encode ``"CHAIN_NUM"`` residue ids as a 0/1 vector over all residues.

    The vector concatenates the chains in ``chain_lengths`` iteration order.
    Residue number ``k`` of a chain maps to position ``offset(chain) + k - 1``,
    i.e. residues are assumed to be numbered consecutively from 1 within each
    chain.

    Entries are skipped with a warning when the chain is not in
    ``chain_lengths``, when no residue number can be parsed, or when the
    number is outside ``1..length`` for that chain.

    Args:
        residue_list: Iterable of ``"CHAIN_NUM"`` strings; the first ``_``
            field is the chain, the first signed integer in the last field is
            the residue number.
        chain_lengths: Ordered mapping chain id -> number of residues.

    Returns:
        ``numpy`` int array of length ``sum(chain_lengths.values())``.
    """
    total_length = sum(chain_lengths.values())
    one_hot = np.zeros(total_length, dtype=int)

    # Compute the starting index for each chain
    chain_offsets = {}
    offset = 0
    for chain_id, length in chain_lengths.items():
        chain_offsets[chain_id] = offset
        offset += length

    # Fill in the one-hot encoding
    for res in residue_list:
        fields = res.split("_")
        chain, res_id = fields[0], fields[-1]
        if chain not in chain_offsets:
            warnings.warn(f"Chain {chain} not found in chain_lengths.")
            continue
        # Signed parse, consistent with get_atoms(); "\d+" turned "-5" into 5.
        numbers = re.findall(r'-?\d+', res_id)
        if not numbers:
            warnings.warn(f"No residue number in {res!r}; skipped.")
            continue
        res_index = int(numbers[0]) - 1  # assume residue index starts from 1
        if not 0 <= res_index < chain_lengths[chain]:
            # Out-of-range indices would otherwise land in a neighbouring
            # chain's slots, wrap around via negative indexing, or raise.
            warnings.warn(f"Residue {res!r} is outside chain {chain} "
                          f"(length {chain_lengths[chain]}); skipped.")
            continue
        one_hot[chain_offsets[chain] + res_index] = 1

    return one_hot


def get_atoms(model, residue_ids):
    """Collect the coordinates of every atom in the listed residues.

    Residues are matched on ``(chain.id, res.id[1])`` only (chain id plus
    residue sequence number), so every residue in that chain sharing the
    number contributes its atoms.

    Args:
        model: A Bio.PDB model, i.e. an iterable of chains.
        residue_ids: Iterable of ``"CHAIN_NUM"`` strings; the first ``_``
            field is the chain and the first signed integer in the last field
            is the residue number. Entries without a number are ignored here
            (to_one_hot reports them).

    Returns:
        ``(N, 3)`` array of atom coordinates. When nothing matches the shape
        is ``(0, 3)`` rather than ``(0,)``, so callers can always subtract a
        3-vector or reduce along ``axis=1``.
    """
    targets = set()
    for rid in residue_ids:
        fields = rid.split("_")
        numbers = re.findall(r"-?\d+", fields[-1])
        if numbers:
            targets.add((fields[0], int(numbers[0])))

    coords = [
        atom.get_coord()
        for chain in model
        for res in chain.get_residues()
        if (chain.id, res.id[1]) in targets
        for atom in res.get_atoms()
    ]
    if not coords:
        return np.empty((0, 3), dtype=np.float32)
    return np.array(coords)


def compute_metrics(pred_bs, label_bs, model, cutoff=4.0):
    """Score predicted pockets against ground-truth pockets for one structure.

    For every (prediction, label) pair this fills score matrices with:

    * residue-level MCC, precision, recall and F1, computed on 0/1 vectors
      over all residues of ``model`` (see to_one_hot);
    * DCA: 1 if the centroid of the predicted pocket's atoms is within
      ``cutoff`` of the nearest atom of the label pocket's residues, else 0;
    * DCC: 1 if the centroids of the predicted and label pocket atoms are
      within ``cutoff`` of each other, else 0.

    A pair where either pocket has no atoms in ``model`` scores 0 on DCA and
    DCC (distances are undefined). The matrices are then reduced by
    greedy_match.

    Args:
        pred_bs: List of predicted pockets, each a list of ``"CHAIN_NUM"`` ids.
        label_bs: List of ground-truth pockets in the same format.
        model: Bio.PDB model the residue ids refer to.
        cutoff: Distance threshold for DCA/DCC, in the units of the atom
            coordinates (Å for PDB input). Defaults to 4.0.

    Returns:
        Dict of metric name -> averaged score, as returned by greedy_match.
    """
    num_pred, num_label = len(pred_bs), len(label_bs)
    chain_lengths = {c.id: len(list(c.get_residues())) for c in model}

    matrices = {k: np.zeros((num_pred, num_label)) for k in (
        'mcc', 'precision', 'recall', 'f1', 'dca', 'dcc')}

    # Label-side data does not depend on the prediction: build it once
    # instead of once per (prediction, label) pair.
    label_vecs = [to_one_hot(label, chain_lengths) for label in label_bs]
    label_coords = [get_atoms(model, label) for label in label_bs]

    for i, pred in enumerate(pred_bs):
        pred_vec = to_one_hot(pred, chain_lengths)
        pred_xyz = get_atoms(model, pred)
        pred_c = None if len(pred_xyz) == 0 else np.mean(pred_xyz, axis=0)

        for j, (label_vec, label_xyz) in enumerate(zip(label_vecs, label_coords)):
            matrices['mcc'][i, j] = matthews_corrcoef(label_vec, pred_vec)
            matrices['precision'][i, j] = precision_score(label_vec, pred_vec, zero_division=0)
            matrices['recall'][i, j] = recall_score(label_vec, pred_vec, zero_division=0)
            matrices['f1'][i, j] = f1_score(label_vec, pred_vec, zero_division=0)

            if pred_c is None or len(label_xyz) == 0:
                # No atoms on one side: mean/min of an empty array would be
                # NaN or raise; leave DCA/DCC at 0 (a miss).
                continue

            label_c = np.mean(label_xyz, axis=0)
            dca = np.min(np.linalg.norm(label_xyz - pred_c, axis=1))
            dcc = np.linalg.norm(label_c - pred_c)
            matrices['dca'][i, j] = int(dca <= cutoff)
            matrices['dcc'][i, j] = int(dcc <= cutoff)

    return greedy_match(matrices, num_pred, num_label)


def greedy_match(matrices, num_pred, num_label):
    """Greedily pair predicted and ground-truth pockets by MCC and average.

    Each round selects the unused (prediction, label) pair with the highest
    MCC (ties go to the first pair in row-major order); every prediction and
    every label is used at most once. Labels left unmatched (fewer
    predictions than labels) score 0 on every metric; surplus predictions
    are ignored. NaN MCC values rank below every real value.

    Args:
        matrices: Dict of metric name -> ``(num_pred, num_label)`` array;
            must contain ``"mcc"``, which drives the matching.
        num_pred: Number of predicted pockets (rows).
        num_label: Number of ground-truth pockets (columns).

    Returns:
        Dict of metric name -> mean over labels of the matched score, rounded
        to 3 decimals (NaN when there are no labels).
    """
    scores = {k: np.zeros(num_label) for k in matrices}
    used_pred, used_label = set(), set()
    mcc = matrices["mcc"]

    for _ in range(min(num_pred, num_label)):
        best = -np.inf
        best_i = best_j = None
        for i in range(num_pred):
            if i in used_pred:
                continue
            for j in range(num_label):
                if j in used_label:
                    continue
                value = mcc[i, j]
                # NaN never compares greater, so without this a round whose
                # remaining MCCs are all NaN would leave best_i/best_j unbound
                # or stale from the previous round.
                if np.isnan(value):
                    value = -np.inf
                if best_i is None or value > best:
                    best, best_i, best_j = value, i, j
        used_pred.add(best_i)
        used_label.add(best_j)
        for k in matrices:
            scores[k][best_j] = matrices[k][best_i, best_j]

    # Guard the empty case explicitly instead of np.mean's RuntimeWarning.
    return {k: round(np.mean(v), 3) if v.size else np.nan for k, v in scores.items()}

def main(args):
    """Evaluate predicted pockets for every protein and write a results CSV.

    Inputs (all taken from ``args``):
      * ``label_csv`` — CSV with ``PDB ID``, ``Label`` and ``Ligand`` columns.
        The number of ``;``-separated ligands plus ``extra_pockets`` is how
        many predicted pockets are kept per protein.
      * ``protein_list`` — one PDB ID per line; blank lines are ignored.
      * ``pred_txt`` — whitespace-separated lines; the second field is a
        comma-separated list of residue tokens for one pocket, of which the
        first and last ``_``-separated fields are kept as chain and residue.
        Pockets are taken in file order.
      * ``pdb_dir`` — directory containing ``<PDB ID>.pdb`` structures.

    Writes one row per protein (MCC, DCA, DCC, Precision, Recall, F1) to
    ``args.out_csv``, creating its parent directory if needed.
    """
    print("Start processing...")
    print("Loading files...")
    # Load all required files
    label_dict = {}
    target_n_dict = {}
    with open(args.label_csv, newline="") as f:  # newline="" as the csv docs require
        for row in csv.DictReader(f):
            pid = row["PDB ID"]
            # Dataset holo243: we recommend Cross_chain_label & All_ligand
            label_dict[pid] = row["Label"]
            target_n_dict[pid] = len(row["Ligand"].split(";"))

    with open(args.protein_list) as f:
        # Skip blank lines (e.g. a trailing empty line); "" is not a PDB ID
        # and would raise KeyError below.
        protein_ids = [line.strip() for line in f if line.strip()]

    with open(args.pred_txt) as f:
        lines = f.readlines()

    pred_dict = {}
    for pid in protein_ids:
        target_n = target_n_dict[pid] + args.extra_pockets
        pockets = []
        for line in lines:
            if len(pockets) >= target_n:
                break  # enough pockets collected; later lines cannot be used
            # NOTE(review): substring match — a line is assigned to pid when
            # pid occurs anywhere in it; confirm against the prediction format.
            if pid not in line:
                continue
            parts = line.strip().split()
            if len(parts) < 2:
                continue
            pockets.append(
                [f"{p.split('_')[0]}_{p.split('_')[-1]}" for p in parts[1].split(',')]
            )
        pred_dict[pid] = pockets

    # Evaluate
    print("Evaluating...")
    out_dir = os.path.dirname(args.out_csv)
    if out_dir:
        # dirname() is "" for a bare filename, and os.makedirs("") raises.
        os.makedirs(out_dir, exist_ok=True)
    pdb_parser = PDBParser(QUIET=True)
    with open(args.out_csv, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["NO.", "PDB ID", "MCC", "DCA", "DCC", "Precision", "Recall", "F1"])
        for idx, pid in enumerate(protein_ids, 1):
            pdb_path = os.path.join(args.pdb_dir, f"{pid}.pdb")
            model = pdb_parser.get_structure("P", pdb_path)[0]
            label = extract_label(label_dict[pid])
            metrics = compute_metrics(pred_dict[pid], label, model)
            writer.writerow([idx, pid] + [metrics[k] for k in ("mcc", "dca", "dcc", "precision", "recall", "f1")])
    print(f"Results saved to: {args.out_csv}")


if __name__ == "__main__":
    # Command-line entry point: every path option is mandatory, only the
    # number of extra pockets has a default.
    parser = argparse.ArgumentParser(
        description="Evaluate predicted binding pockets against ground truth labels."
    )
    required_options = (
        ("--label_csv", "Path to label CSV"),
        ("--protein_list", "File containing PDB IDs."),
        ("--pred_txt", "Prediction file"),
        ("--pdb_dir", "Directory with PDB files to compute DCA and DCC."),
        ("--out_csv", "Output path for evaluation results."),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument(
        "--extra_pockets",
        type=int,
        default=0,
        help="Number of extra pockets to include beyond top-n.",
    )

    args = parser.parse_args()
    main(args)