import pandas as pd
import numpy as np
import ta
from stable_baselines3 import PPO
from gym import Env, spaces
import os
import json
from stable_baselines3.common.callbacks import BaseCallback

class IterationStatsCallback(BaseCallback):
    def __init__(self, env, n_steps, verbose=0):
        super().__init__(verbose)
        self.env = env
        self.n_steps = n_steps
        self.iteration_count = 0

    def _on_step(self) -> bool:
        # Reinicia los contadores cada iteración
        if self.n_calls % self.n_steps == 0:
            self.iteration_count += 1
            print(f"\n🔄 Iteración {self.iteration_count} completada max_steps: {self.env.max_steps}")
            print(f"   - Operaciones abiertas: {self.env.total_open_operations}")
            print(f"   - Operaciones cerradas: {self.env.total_close_operations}")
            print(f"   - Ventas con ganancia: {self.env.total_profit_sales}")
            print(f"   - Ventas con pérdida: {self.env.total_loss_sales}")
            print(f"   - Total Stop Loss: {self.env.total_stop_loss}")

            # Reiniciar contadores
            self.env.total_open_operations = 0
            self.env.total_close_operations = 0
            self.env.total_profit_sales = 0
            self.env.total_loss_sales = 0
            self.env.total_stop_loss = 0

        return True

class TradingEnv(Env):
    def __init__(self, dataframe):
        super(TradingEnv, self).__init__()

        # Calcular indicadores técnicos
        dataframe['rsi'] = ta.momentum.RSIIndicator(dataframe['close'], window=14).rsi()
        dataframe['ema_short'] = ta.trend.EMAIndicator(dataframe['close'], window=9).ema_indicator()
        dataframe['ema_long'] = ta.trend.EMAIndicator(dataframe['close'], window=21).ema_indicator()
        dataframe['obv'] = ta.volume.OnBalanceVolumeIndicator(dataframe['close'], dataframe['volume']).on_balance_volume()
        dataframe['close_open_diff'] = dataframe['close'] - dataframe['open']
        dataframe['high_low_diff'] = dataframe['high'] - dataframe['low']
        dataframe['close_low_diff'] = dataframe['close'] - dataframe['low']
        dataframe['high_close_diff'] = dataframe['high'] - dataframe['close']
        dataframe['position'] = 0
        dataframe['macd'] = ta.trend.MACD(dataframe['close']).macd()
        dataframe['stoch_k'] = ta.momentum.StochasticOscillator(dataframe['close'], dataframe['high'], dataframe['low']).stoch()
        dataframe['cci'] = ta.trend.CCIIndicator(dataframe['high'], dataframe['low'], dataframe['close']).cci()

        # Remover filas con valores NaN
        dataframe.dropna(inplace=True)

        # Filtrar solo las columnas necesarias para el entorno
        expected_columns = ['close', 'volume', 'rsi', 'ema_short', 'ema_long', 'obv', 'close_open_diff', 'high_low_diff', 'close_low_diff', 'high_close_diff', 'position']
        self.data = dataframe[expected_columns]

        # Verificación del espacio de observación
        if not set(expected_columns).issubset(self.data.columns):
            raise ValueError(f"Error: El espacio de observación debe tener las siguientes 11 columnas: {expected_columns}, pero el dataframe tiene {self.data.columns.tolist()}.")

        self.current_step = 0
        self.max_steps = len(self.data) - 1

        # Espacio de observación
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float32)

        # Espacio de acción: 0 (Hold), 1 (Buy), 2 (Sell)
        self.action_space = spaces.Discrete(3)

        # Estado inicial
        self.balance = 1000
        self.position = 0
        self.entry_price = 0

    def reset(self):
        self.current_step = 0
        self.balance = 1000
        self.position = 0
        self.entry_price = 0
        
        return self._next_observation()

    def step(self, action):
        done = False
        reward = 0
        row = self.data.iloc[self.current_step]
        price = row['close']
        daily_volatility = row['high_low_diff'] / row['close'] * 100
        
        # Contadores de operaciones
        if not hasattr(self, 'total_open_operations'):
            self.total_open_operations = 0
            self.total_close_operations = 0
            self.total_profit_sales = 0
            self.total_loss_sales = 0
            self.total_stop_loss = 0
    
        # Parámetros ROI (ajusta según prefieras)
        roi_target = 1  # Porcentaje de ganancia objetivo
        max_holding_steps = 10  # Máximo tiempo de retención (~24 horas en velas de 5m)
    
        # Acción de Compra
        if action == 1 and self.position == 0:
            self.entry_price = price
            self.position = 1
            self.entry_step = self.current_step
            self.total_open_operations += 1
            reward += 2  # Incentivo para tomar riesgos al abrir posiciones
            
        # Acción de Venta o ROI Alcanzado
        elif self.position == 1:
            profit_percent = (price - self.entry_price) / self.entry_price * 100
            holding_duration = self.current_step - self.entry_step
            
            # Cerrar por ROI o tiempo máximo
            if profit_percent >= roi_target or holding_duration >= max_holding_steps:
                self.balance += profit_percent
                self.position = 0
                self.total_close_operations += 1
                
                if profit_percent > 0:
                    self.total_profit_sales += 1
                    reward += 5  # Recompensa alta para ventas positivas
                else:
                    self.total_loss_sales += 1
                    reward -= 5  # Penalización fuerte para ventas negativas
                
                #print(f"🏁 Venta cerrada con {profit_percent:.2f}% de profit en {holding_duration} pasos")
        
        # Penalización para operaciones largas
        if self.position == 1 and (self.current_step - self.entry_step) >= max_holding_steps:
            reward -= 10  # Penalización fuerte para operaciones largas
        
        # Recompensa por operar en días volátiles
        if daily_volatility > 2 and action != 0:
            reward += 0.2  # Recompensa adicional para incentivar operaciones en días volátiles
        
        # Penalización por inactividad
        if action == 0 and self.position == 0:
            reward -= 0.05  # Penalización leve para evitar la inactividad constante
    
        # Pasar al siguiente paso
        self.current_step += 1
        if self.current_step >= self.max_steps:
            done = True
            print(f"🏁 Fin del entorno alcanzado en paso {self.current_step} con balance final de {self.balance:.2f} USDC")
    
        # Estado siguiente
        obs = self._next_observation()
        return obs, reward, done, {}

    def _next_observation(self):
        row = self.data.iloc[self.current_step]
        return np.array([
            row['close'],
            row['volume'],
            row['rsi'],
            row['ema_short'],
            row['ema_long'],
            row['obv'],
            row['close_open_diff'],
            row['high_low_diff'],
            row['close_low_diff'],
            row['high_close_diff'],
            self.position
        ])

    def render(self):
        print(f"Step: {self.current_step}, Balance: {self.balance}, Position: {self.position}")

def train_model(dataframe, total_timesteps=100000):
    env = TradingEnv(dataframe)
    n_steps = 128  # Ajusta esto según tu configuración de PPO
    callback = IterationStatsCallback(env, n_steps)
    model = PPO(
        "MlpPolicy", 
        env, 
        verbose=1, 
        learning_rate=0.003,  # Tasa de aprendizaje más alta
        n_steps=64,  # Reduce para más frecuencia de actualización
        batch_size=64,  # Reduce para permitir más exploración
        gae_lambda=0.8,  # Reduce para hacer el modelo más reactivo
        gamma=0.85,  # Reduce para hacer el modelo más "corto-plazista"
        clip_range=0.2,  # Reduce para permitir más exploración
        ent_coef=0.001,
        vf_coef=0.5,
        max_grad_norm=0.5
    )
    model.learn(total_timesteps=total_timesteps, callback=callback)
    model.save("ppo_trading_model")
    print("Modelo entrenado y guardado como 'ppo_trading_model.zip'")

if __name__ == "__main__":
    # 📂 Cargar datos del archivo JSON
    with open("/root/freqtrade/user_data/data/binance/BTC_USDC-5m.json", "r") as file:
        raw_data = json.load(file)

    # Convertir los datos JSON a DataFrame con nombres de columna correctos
    df = pd.DataFrame(raw_data, columns=["timestamp", "open", "high", "low", "close", "volume"])
    
    # Convertir timestamp a datetime (opcional)
    df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
    # Imprimir la fecha de inicio y fin de los datos
    start_date = df['timestamp'].min()
    end_date = df['timestamp'].max()
    print(f"📅 Fecha de inicio: {start_date}")
    print(f"📅 Fecha de fin: {end_date}")
    print(f"📊 Total de velas: {len(df)}")

    # Remover filas con valores NaN
    df.dropna(inplace=True)

    # Verificación de las 11 columnas requeridas
    print("Columnas del dataframe:", df.columns)

    # Entrenar el modelo
    train_model(df, total_timesteps=50000)
