import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs from TensorFlow (must be set before importing tensorflow)

import pandas as pd
import numpy as np
import ta
import argparse
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
import joblib

# Arguments
parser = argparse.ArgumentParser()
parser.add_argument("--pair", required=True, help="Par en formato WCT/USDC")
parser.add_argument("--seq-len", type=int, default=20, help="Longitud de la ventana (n velas hacia atrás)")
parser.add_argument("--fwd", type=int, default=1, help="Ventanas a futuro para el target")
parser.add_argument("--fwd-profit-up", type=int, default=36, help="Ventanas a futuro para el target de profit-up (ej: 36 min)")
parser.add_argument("--profit-up", type=float, default=0.38, help="Porcentaje de subida para el target de profit-up (ej: 0.38)")
args = parser.parse_args()

SEQ_LEN = args.seq_len
FWD = args.fwd
FWD_PROFIT_UP = args.fwd_profit_up
PROFIT_UP = args.profit_up / 100  # convert percentage to a fraction

pair_filename = args.pair.replace("/", "_")
data_path = f"/root/freqtrade/user_data/data/binance/{pair_filename}-1m.feather"
if not os.path.isfile(data_path):
    raise FileNotFoundError(f"File not found: {data_path}")

# === Load data ===
df = pd.read_feather(data_path)
df = df.sort_values("date").reset_index(drop=True)
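# The Freqtrade feather dump is expected to provide date, open, high, low,
# close and volume columns; the feature block below relies on them.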

# === Advanced features (same as before) ===
df["rsi"] = ta.momentum.RSIIndicator(df["close"], window=14).rsi()
df["rsi_7"] = ta.momentum.RSIIndicator(df["close"], window=7).rsi()
df["rsi_21"] = ta.momentum.RSIIndicator(df["close"], window=21).rsi()
df["ema_10"] = ta.trend.EMAIndicator(df["close"], window=10).ema_indicator()
df["ema_21"] = ta.trend.EMAIndicator(df["close"], window=21).ema_indicator()
df["ema_50"] = ta.trend.EMAIndicator(df["close"], window=50).ema_indicator()
macd = ta.trend.MACD(df["close"])
df["macd"] = macd.macd()
df["macd_signal"] = macd.macd_signal()
df["macd_diff"] = macd.macd_diff()
df["price_change"] = df["close"].pct_change().clip(-1, 1)
df["volume_change"] = df["volume"].pct_change().clip(-1, 1)
df["rsi_change"] = df["rsi"].diff().clip(-1, 1)
df["ema_diff_10"] = df["close"] - df["ema_10"]
df["ema_diff_21"] = df["close"] - df["ema_21"]
df["ema_diff_50"] = df["close"] - df["ema_50"]
df["upper_shadow"] = df["high"] - np.maximum(df["close"], df["open"])
df["lower_shadow"] = np.minimum(df["close"], df["open"]) - df["low"]
df["body_size"] = np.abs(df["close"] - df["open"])
bb = ta.volatility.BollingerBands(df["close"])
df["boll_ub"] = bb.bollinger_hband()
df["boll_lb"] = bb.bollinger_lband()
df["boll_dist"] = df["close"] - bb.bollinger_mavg()
df["atr"] = ta.volatility.AverageTrueRange(df["high"], df["low"], df["close"]).average_true_range()
df["stoch_rsi"] = ta.momentum.StochRSIIndicator(df["close"]).stochrsi()
df["roc"] = ta.momentum.ROCIndicator(df["close"], window=10).roc()

# === Targets: N candles ahead ===
df["future_close"] = df["close"].shift(-FWD)
df["future_rsi"] = df["rsi"].shift(-FWD)
df["target_price_up"] = (df["future_close"] > df["close"]).astype(np.float32)
df["target_rsi_up"] = (df["future_rsi"] > df["rsi"]).astype(np.float32)
# --- New target ---
df["future_close_profit"] = df["close"].shift(-FWD_PROFIT_UP)
df["target_profit_up"] = ((df["future_close_profit"] > df["close"] * (1 + PROFIT_UP))).astype(np.float32)

print(f"SEQ_LEN={SEQ_LEN}, FWD={FWD}")
print("Proporción target_price_up:", df["target_price_up"].mean())
print("Proporción target_rsi_up:", df["target_rsi_up"].mean())
print("Proporción target_profit_up:", df["target_profit_up"].mean())
print("Total:", len(df))

# Cleanup
df = df.fillna(0)
df.reset_index(drop=True, inplace=True)

# === Features and normalization ===
features = [
    "open", "high", "low", "close", "volume",
    "rsi", "rsi_7", "rsi_21",
    "ema_10", "ema_21", "ema_50",
    "macd", "macd_signal", "macd_diff",
    "price_change", "volume_change", "rsi_change",
    "ema_diff_10", "ema_diff_21", "ema_diff_50",
    "upper_shadow", "lower_shadow", "body_size",
    "boll_ub", "boll_lb", "boll_dist",
    "atr", "stoch_rsi", "roc"
]
X = df[features].astype(np.float32).values
y_price = df["target_price_up"].astype(np.float32).values
y_rsi = df["target_rsi_up"].astype(np.float32).values
y_profit_up = df["target_profit_up"].astype(np.float32).values

scaler = RobustScaler()
X_scaled = scaler.fit_transform(X)
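# Note: the scaler is fitted on the full dataset before the chronological split,
# so test-set statistics leak into the scaling. A stricter setup would fit the
# scaler on the training slice only and reuse it on the test slice.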

# === Sequences ===
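# Each sample is a window of SEQ_LEN consecutive scaled rows; its label comes
# from the row immediately after the window (whose own target already looks
# FWD / FWD_PROFIT_UP candles further ahead).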
X_clean = []
y_price_clean = []
y_rsi_clean = []
y_profit_up_clean = []
for i in range(len(X_scaled) - SEQ_LEN):
    x_seq = X_scaled[i:i+SEQ_LEN]
    yp = y_price[i+SEQ_LEN]
    yr = y_rsi[i+SEQ_LEN]
    ypu = y_profit_up[i+SEQ_LEN]
    if np.isfinite(x_seq).all() and np.isfinite(yp) and np.isfinite(yr) and np.isfinite(ypu):
        X_clean.append(x_seq)
        y_price_clean.append(yp)
        y_rsi_clean.append(yr)
        y_profit_up_clean.append(ypu)
X_seq = np.array(X_clean, dtype=np.float32)
y_price_seq = np.array(y_price_clean, dtype=np.float32)
y_rsi_seq = np.array(y_rsi_clean, dtype=np.float32)
y_profit_up_seq = np.array(y_profit_up_clean, dtype=np.float32)

print("X shape:", X_seq.shape)
print("Clases presentes en y_price:", np.unique(y_price_seq, return_counts=True))
print("Clases presentes en y_rsi:", np.unique(y_rsi_seq, return_counts=True))
print("Clases presentes en y_profit_up:", np.unique(y_profit_up_seq, return_counts=True))

# === Split ===
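# shuffle=False keeps the chronological order, so the test set is the most
# recent 20% of sequences (no look-ahead introduced by shuffling).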
X_train, X_test, y_price_train, y_price_test, y_rsi_train, y_rsi_test, y_profit_up_train, y_profit_up_test = train_test_split(
    X_seq, y_price_seq, y_rsi_seq, y_profit_up_seq, test_size=0.2, shuffle=False
)

# === Triple-output LSTM model ===
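# A single LSTM trunk feeds three independent sigmoid heads; Keras sums the
# three binary cross-entropy losses with equal weight by default.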
inp = Input(shape=(SEQ_LEN, X_seq.shape[2]))
x = LSTM(64, return_sequences=False)(inp)
x = Dropout(0.2)(x)
out_price = Dense(1, activation='sigmoid', name='out_price')(x)
out_rsi = Dense(1, activation='sigmoid', name='out_rsi')(x)
out_profit_up = Dense(1, activation='sigmoid', name='out_profit_up')(x)

model = Model(inputs=inp, outputs=[out_price, out_rsi, out_profit_up])
model.compile(
    loss={'out_price': 'binary_crossentropy', 'out_rsi': 'binary_crossentropy', 'out_profit_up': 'binary_crossentropy'},
    optimizer=Adam(learning_rate=0.0005),
    metrics={'out_price': 'accuracy', 'out_rsi': 'accuracy', 'out_profit_up': 'accuracy'}
)

early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
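# Note: val_loss here is the combined loss summed over the three output heads.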

history = model.fit(
    X_train, [y_price_train, y_rsi_train, y_profit_up_train],
    epochs=50,
    batch_size=64,
    validation_split=0.2,
    callbacks=[early_stop]
)

# === Predictions ===
# Keras returns one (n, 1) array per output head; flatten to 1-D for the metrics below
y_price_pred, y_rsi_pred, y_profit_up_pred = [p.ravel() for p in model.predict(X_test)]

# === Evaluation ===
for name, y_true, y_pred in [
    ('Price up', y_price_test, y_price_pred),
    ('RSI up', y_rsi_test, y_rsi_pred),
    ('Profit-up', y_profit_up_test, y_profit_up_pred)
]:
    print(f"\n--- Evaluation: {name} ---")
    print("ROC AUC:", roc_auc_score(y_true, y_pred))  # threshold-independent, so printed once per target
    for th in [0.1, 0.2, 0.3, 0.4, 0.5]:
        yp_bin = (y_pred > th).astype(int)
        print(f"\nTHRESHOLD {th}:")
        print(classification_report(y_true, yp_bin, digits=3, zero_division=0))
        print(confusion_matrix(y_true, yp_bin))

# === Plot the distribution of predicted probabilities ===
import matplotlib.pyplot as plt
plt.figure(figsize=(10,4))
plt.hist(y_price_pred, bins=50, alpha=0.5, label='Prob price up')
plt.hist(y_rsi_pred, bins=50, alpha=0.5, label='Prob RSI up')
plt.hist(y_profit_up_pred, bins=50, alpha=0.5, label='Prob profit-up')
plt.title("Distribution of predicted probabilities")
plt.xlabel("Predicted probability")
plt.ylabel("Frequency")
plt.legend()
plt.show()

# === Save model and scaler ===
model_path = f"modelo_lstm_triple_{pair_filename}_seq{SEQ_LEN}_fwd{FWD}_profit{FWD_PROFIT_UP}_{args.profit_up}.h5"
model.save(model_path)
print(f"[INFO] LSTM model saved as: {model_path}")
scaler_path = f"scaler_{pair_filename}_seq{SEQ_LEN}_fwd{FWD}.gz"
joblib.dump(scaler, scaler_path)
print(f"[INFO] Scaler saved as: {scaler_path}")