# Source code for bbstrader.core.strategy

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import pytz
from loguru import logger

from bbstrader.config import BBSTRADER_DIR
from bbstrader.models.optimization import optimized_weights

# Register a file sink on the imported loguru ``logger``: records are written
# to ``logs/strategy.log`` under BBSTRADER_DIR, with ``level="INFO"`` as the
# configured threshold and a "time | level | name | message" line layout.
logger.add(
    f"{BBSTRADER_DIR}/logs/strategy.log",
    enqueue=True,
    level="INFO",
    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name} | {message}",
)


# Names exported by ``from bbstrader.core.strategy import *``.
# NOTE(review): ``BaseStrategy`` and ``TWSStrategy`` are defined in this module
# but are not listed here — confirm the omission is intentional.
__all__ = [
    "TradeAction",
    "TradeSignal",
    "TradingMode",
    "generate_signal",
    "Strategy",
]


class TradeAction(Enum):
    """
    Enumeration of the actions a trade signal can request.

    ``BUY`` and ``LONG`` share the value ``"LONG"``; ``SELL`` and ``SHORT``
    share the value ``"SHORT"``. Because ``BUY`` and ``SELL`` are declared
    first, ``LONG`` and ``SHORT`` are enum aliases of them (for example
    ``TradeAction.LONG is TradeAction.BUY``).

    ``str(member)`` yields the member's value rather than its qualified name.
    """

    # --- Directional entries (declaration order fixes which name is canonical)
    BUY = "LONG"
    SELL = "SHORT"
    LONG = "LONG"
    SHORT = "SHORT"

    # --- Order-type entries
    # NOTE(review): presumably buy/sell market, limit, stop and stop-limit
    # orders — confirm against the execution layer.
    BMKT = "BMKT"
    SMKT = "SMKT"
    BLMT = "BLMT"
    SLMT = "SLMT"
    BSTP = "BSTP"
    SSTP = "SSTP"
    BSTPLMT = "BSTPLMT"
    SSTPLMT = "SSTPLMT"

    # --- Exits
    EXIT = "EXIT"
    EXIT_LONG = "EXIT_LONG"
    EXIT_SHORT = "EXIT_SHORT"
    EXIT_STOP = "EXIT_STOP"
    EXIT_LIMIT = "EXIT_LIMIT"
    EXIT_LONG_STOP = "EXIT_LONG_STOP"
    EXIT_LONG_LIMIT = "EXIT_LONG_LIMIT"
    EXIT_SHORT_STOP = "EXIT_SHORT_STOP"
    EXIT_SHORT_LIMIT = "EXIT_SHORT_LIMIT"
    EXIT_LONG_STOP_LIMIT = "EXIT_LONG_STOP_LIMIT"
    EXIT_SHORT_STOP_LIMIT = "EXIT_SHORT_STOP_LIMIT"

    # --- Bulk exits
    EXIT_PROFITABLES = "EXIT_PROFITABLES"
    EXIT_LOSINGS = "EXIT_LOSINGS"
    EXIT_ALL_POSITIONS = "EXIT_ALL_POSITIONS"
    EXIT_ALL_ORDERS = "EXIT_ALL_ORDERS"

    def __str__(self):
        """Return the raw string value of the member."""
        return self.value
@dataclass()
class TradeSignal:
    """
    Represents a trading signal generated by a trading system or strategy.

    Notes
    -----
    Attributes:

    - id (int): A unique identifier for the trade signal or the strategy.
    - symbol (str): The trading symbol (e.g., stock ticker, forex pair, crypto asset).
    - action (TradeAction): The trading action to perform. Must be an instance
      of the ``TradeAction`` enum (e.g., BUY, SELL).
    - price (float, optional): The price at which the trade should be executed.
    - stoplimit (float, optional): A stop-limit price for the trade.
      Must not be set without specifying a price.
    - sl (float, optional): A stop loss price for the trade.
    - tp (float, optional): A take profit price for the trade.
    - comment (str, optional): An optional comment or description related to
      the trade signal.

    Raises:
        TypeError: If ``action`` is not a ``TradeAction`` instance.
        ValueError: If ``stoplimit`` is given while ``price`` is ``None``.
    """

    id: int
    symbol: str
    action: TradeAction
    # These fields default to None, so they are annotated Optional rather
    # than as bare float/str.
    price: Optional[float] = None
    stoplimit: Optional[float] = None
    sl: Optional[float] = None
    tp: Optional[float] = None
    comment: Optional[str] = None

    def __post_init__(self) -> None:
        """Validate the action type and the price/stoplimit pairing."""
        if not isinstance(self.action, TradeAction):
            raise TypeError(
                f"action must be of type TradeAction, not {type(self.action)}"
            )
        if self.stoplimit is not None and self.price is None:
            raise ValueError("stoplimit cannot be set without price")

    def __repr__(self) -> str:
        """Return a compact representation; a missing comment is shown as ''."""
        return (
            f"TradeSignal(id={self.id}, symbol='{self.symbol}', action='{self.action.value}', "
            f"price={self.price}, stoplimit={self.stoplimit}, sl={self.sl}, tp={self.tp}, comment='{self.comment or ''}')"
        )
def generate_signal(
    id: int,
    symbol: str,
    action: TradeAction,
    price: Optional[float] = None,
    stoplimit: Optional[float] = None,
    sl: Optional[float] = None,
    tp: Optional[float] = None,
    comment: Optional[str] = None,
) -> TradeSignal:
    """
    Build a ``TradeSignal`` from the given fields.

    This is a thin convenience wrapper: every argument is forwarded unchanged
    to the ``TradeSignal`` constructor, which performs the validation.

    Args:
        id (int): Unique identifier for the trade signal.
        symbol (str): The symbol for which the trade signal is generated.
        action (TradeAction): The action to be taken (e.g., BUY, SELL).
        price (float, optional): The price at which to execute the trade.
        stoplimit (float, optional): The stop limit price for the trade.
        sl (float, optional): The stop loss price for the trade.
        tp (float, optional): The take profit price for the trade.
        comment (str, optional): Additional comments for the trade.

    Returns:
        TradeSignal: A TradeSignal object containing the details of the trade signal.
    """
    return TradeSignal(
        id=id,
        symbol=symbol,
        action=action,
        price=price,
        stoplimit=stoplimit,
        sl=sl,
        tp=tp,
        comment=comment,
    )
class TradingMode(Enum):
    """Mode a strategy runs in: ``BACKTEST`` or ``LIVE``."""

    BACKTEST = "BACKTEST"
    LIVE = "LIVE"

    def isbacktest(self) -> bool:
        """Return True when this member is ``TradingMode.BACKTEST``."""
        return self is TradingMode.BACKTEST

    def islive(self) -> bool:
        """Return True when this member is ``TradingMode.LIVE``."""
        return self is TradingMode.LIVE
class Strategy(metaclass=ABCMeta):
    """
    Abstract interface for objects that turn market data into trade signals.

    All "strategy logic" lives behind this interface. Strategies are kept
    separate from the ``Portfolio`` so that several strategies can feed ideas
    to one larger portfolio, which then handles its own risk (sector
    allocation, leverage, ...).

    The hierarchy is deliberately small: ``calculate_signals`` is the only
    abstract method and is annotated to return a list of ``TradeSignal``
    objects. The remaining methods are optional hooks that do nothing by
    default: checking pending orders, receiving portfolio updates, updating
    trades from fills, and running period-end checks.
    """

    @abstractmethod
    def calculate_signals(self, *args: Any, **kwargs: Any) -> List[TradeSignal]:
        """Compute the strategy's signals; concrete strategies must override."""
        raise NotImplementedError("Should implement calculate_signals()")

    def check_pending_orders(self, *args: Any, **kwargs: Any) -> None:
        """Optional hook; the default implementation does nothing."""
        return None

    def get_update_from_portfolio(self, *args: Any, **kwargs: Any) -> None:
        """Optional hook; the default implementation does nothing."""
        return None

    def update_trades_from_fill(self, *args: Any, **kwargs: Any) -> None:
        """Optional hook; the default implementation does nothing."""
        return None

    def perform_period_end_checks(self, *args: Any, **kwargs: Any) -> None:
        """Optional hook; the default implementation does nothing."""
        return None
class BaseStrategy(Strategy):
    """
    Base class containing shared logic for both Backtest and Live MT5 strategies.

    This class handles configuration, logging, and common utility calculations.
    Concrete subclasses must provide the ``cash`` property and
    ``get_asset_values``.
    """

    tf: str
    id: int
    ID: int
    max_trades: Dict[str, int]
    risk_budget: Optional[Union[Dict[str, float], str]]
    symbols: List[str]
    logger: "logger"  # type: ignore
    kwargs: Dict[str, Any]
    periodes: int
    NAME: str
    DESCRIPTION: str

    def __init__(
        self,
        symbol_list: List[str],
        **kwargs: Any,
    ) -> None:
        """
        Store the symbol universe and read the common options from ``kwargs``.

        Args:
            symbol_list : Symbols handled by the strategy.
            **kwargs : Optional settings. The keys read here are
                ``risk_weights``, ``max_trades`` (default: one trade per
                symbol), ``time_frame`` (default ``"D1"``) and ``logger``
                (default: the module-level logger). The full mapping is kept
                on ``self.kwargs``.

        Raises:
            ValueError: If ``risk_weights`` is a dict that fails validation.
        """
        self.symbols = symbol_list
        self.risk_budget = self._check_risk_budget(**kwargs)
        self.max_trades = kwargs.get("max_trades", {s: 1 for s in self.symbols})
        self.tf = kwargs.get("time_frame", "D1")
        self.logger = kwargs.get("logger") or logger
        self.kwargs = kwargs
        self.periodes = 0

    def _check_risk_budget(
        self, **kwargs: Any
    ) -> Optional[Union[Dict[str, float], str]]:
        """
        Validate and return the ``risk_weights`` option.

        Returns:
            The weights dict when it covers every symbol and sums to 1, the
            value itself when it is a string, otherwise ``None``.

        Raises:
            ValueError: If a symbol has no weight or the weights do not sum to 1.
        """
        weights = kwargs.get("risk_weights")
        if isinstance(weights, dict):
            for asset in self.symbols:
                if asset not in weights:
                    raise ValueError(f"Risk budget for asset {asset} is missing.")
            # Compare the raw sum. Rounding it to the nearest integer first
            # would accept any total between roughly 0.5 and 1.5.
            # np.isclose already absorbs floating-point noise.
            total_risk = float(sum(weights.values()))
            if not np.isclose(total_risk, 1.0):
                raise ValueError(f"Risk budget weights must sum to 1. got {total_risk}")
            return weights
        if isinstance(weights, str):
            return weights
        return None

    @property
    @abstractmethod
    def cash(self) -> float:
        """Returns the available cash (virtual or real)."""
        raise NotImplementedError

    def calculate_signals(self, *args: Any, **kwargs: Any) -> List[TradeSignal]:
        """
        Provides the mechanisms to calculate signals for the strategy.

        This method should return a list of signals for the strategy.
        Each signal must be a ``TradeSignal`` object with the following attributes:

        - ``action``: The order to execute on the symbol (LONG, SHORT, EXIT, etc.),
          see ``TradeAction``.
        - ``price``: The price at which to execute the action, used for pending orders.
        - ``stoplimit``: The stop-limit price for STOP-LIMIT orders, used for
          pending stop limit orders.
        - ``id``: The unique identifier for the strategy or order.
        - ``comment``: An optional comment or description related to the trade signal.
        """
        raise NotImplementedError("Should implement calculate_signals()")

    def perform_period_end_checks(self, *args: Any, **kwargs: Any) -> None:
        """
        Some strategies may require additional checks at the end of the period,
        such as closing all positions or orders or tracking the performance of
        the strategy etc.

        This method is called at the end of the period to perform such checks.
        The default implementation does nothing.
        """
        pass

    @abstractmethod
    def get_asset_values(
        self,
        symbol_list: List[str],
        window: int,
        value_type: str = "returns",
        array: bool = True,
        **kwargs,
    ) -> Optional[Dict[str, Union[np.ndarray, pd.Series]]]:
        """
        Get the historical OHLCV values, returns, or a custom value for the
        assets in the symbol list.

        Args:
            symbol_list : List of ticker symbols to fetch values for.
            window : The lookback period for requesting the data.
            value_type : The type of value to get (e.g., returns, open, high,
                low, close, adjclose, volume).
            array : If True, return the values as numpy arrays, otherwise as
                pandas Series.
            **kwargs : Extra options for concrete implementations.
                NOTE(review): earlier documentation listed ``bars``, ``mode``,
                ``tf`` and ``error`` here; they are not part of this
                signature — confirm against the subclasses.

        Returns:
            asset_values : Historical values of the assets in the symbol list,
            keyed by symbol, or ``None``.

        Note:
            In Live mode, the ``bbstrader.metatrader.rates.Rates`` class is used
            to get the historical data, so the value_type must be 'returns',
            'open', 'high', 'low', 'close', 'adjclose', 'volume'.
        """
        raise NotImplementedError

    def apply_risk_management(
        self,
        optimizer: str,
        symbols: Optional[List[str]] = None,
        freq: int = 252,
    ) -> Optional[Dict[str, float]]:
        """
        Compute per-symbol weights with the given optimizer.

        Close prices for the last ``freq`` periods are fetched through
        ``get_asset_values``, rows containing any NaN are dropped, and the
        result is passed to ``optimized_weights``.

        Args:
            optimizer : Optimization method forwarded to ``optimized_weights``.
                ``None`` disables the computation.
            symbols : Symbols to optimize; defaults to ``self.symbols``.
            freq : Lookback window, also forwarded as ``freq`` to the optimizer.

        Returns:
            ``None`` when no optimizer is given or no prices are available;
            otherwise a mapping of symbol to the absolute value of its weight.
            If the optimizer raises, every symbol gets weight ``0.0``.
        """
        if optimizer is None:
            return None
        symbols = symbols or self.symbols
        prices = self.get_asset_values(
            symbol_list=symbols,
            window=freq,
            value_type="close",
            array=False,
        )
        if prices is None:
            return None
        prices = pd.DataFrame(prices)
        prices = prices.dropna(axis=0, how="any")
        try:
            weights = optimized_weights(prices=prices, freq=freq, method=optimizer)
            return {symbol: abs(weight) for symbol, weight in weights.items()}
        except Exception as exc:
            # Best-effort fallback is kept, but the failure is logged so a
            # broken optimizer does not silently zero every allocation.
            log = getattr(self, "logger", None) or logger
            log.error(
                f"Risk optimization with method '{optimizer}' failed: {exc!r}; "
                "falling back to zero weights"
            )
            return {symbol: 0.0 for symbol in symbols}

    def get_quantity(
        self,
        symbol: str,
        weight: float,
        price: Optional[float] = None,
        volume: Optional[float] = None,
        maxqty: Optional[int] = None,
    ) -> int:
        """
        Calculate the quantity to buy or sell for a given symbol.

        The dollar volume defaults to ``round(cash * weight)`` and is divided
        by the price and then by the symbol's ``max_trades`` entry.

        Args:
            symbol : The symbol for the trade.
            weight : Fraction of the available cash to allocate.
            price : Price to use; when omitted the latest close is fetched
                through ``get_asset_values``.
            volume : Dollar amount to allocate; overrides ``cash * weight``.
            maxqty : Optional upper bound on the returned quantity.

        Returns:
            qty : The quantity to buy or sell for the symbol. ``0`` when there
            is no cash, the weight is zero, or the symbol's trade cap is zero.
            ``1`` when no usable price/volume is available (missing, NaN,
            non-numeric, or a zero price).
        """
        current_cash = self.cash
        if (
            current_cash is None
            or weight == 0
            or current_cash == 0
            or np.isnan(current_cash)
        ):
            return 0
        if price is None:
            vals = self.get_asset_values(
                [symbol], window=1, value_type="close", array=True
            )
            if vals and symbol in vals and len(vals[symbol]) > 0:
                price = float(vals[symbol][-1])
        if volume is None:
            volume = round(current_cash * weight)
        if (
            price is None
            or not isinstance(price, (int, float, np.number))
            or volume is None
            or not isinstance(volume, (int, float, np.number))
            or np.isnan(float(price))
            or np.isnan(float(volume))
            # A zero price would make the division below blow up; treat it
            # like any other unusable price.
            or float(price) == 0
        ):
            if weight != 0:
                return 1
            return 0
        trades_cap = self.max_trades.get(symbol, 1)
        if trades_cap == 0:
            # No trades allowed for this symbol: nothing to size (and no
            # division by zero).
            return 0
        qty = round(volume / price, 2)
        qty = max(qty, 0) / trades_cap
        if maxqty is not None:
            qty = min(qty, maxqty)
        return int(max(round(qty, 2), 0))

    def get_quantities(
        self, quantities: Optional[Union[Dict[str, int], int]]
    ) -> Dict[str, Optional[int]]:
        """
        Get the quantities to buy or sell for the symbols in the strategy.

        This method is used when we need to assign different quantities to the
        symbols.

        Args:
            quantities : ``None`` (every symbol maps to ``None``), a single int
                applied to every symbol, or a ready-made mapping that is
                returned as is.

        Raises:
            TypeError: If ``quantities`` is none of the supported types.
        """
        if quantities is None:
            return {symbol: None for symbol in self.symbols}
        if isinstance(quantities, dict):
            return quantities
        elif isinstance(quantities, int):
            return {symbol: quantities for symbol in self.symbols}
        raise TypeError(f"Unsupported type for quantities: {type(quantities)}")

    @staticmethod
    def calculate_pct_change(current_price: float, lh_price: float) -> float:
        """Return the change from ``lh_price`` to ``current_price`` in percent."""
        return ((current_price - lh_price) / lh_price) * 100

    @staticmethod
    def is_signal_time(period_count: Optional[int], signal_inverval: int) -> bool:
        """
        Check if we can generate a signal based on the current period count.

        We use the signal interval as a form of periodicity or rebalancing period.

        Args:
            period_count : The current period count (e.g., number of bars).
            signal_inverval : The signal interval for generating signals
                (e.g., every 5 bars).

        Returns:
            bool : True if we can generate a signal, False otherwise. A count
            of ``0`` or ``None`` always allows a signal.
        """
        if period_count == 0 or period_count is None:
            return True
        return period_count % signal_inverval == 0

    @staticmethod
    def get_current_dt(time_zone: str = "US/Eastern") -> datetime:
        """Return the current timezone-aware datetime in ``time_zone``."""
        return datetime.now(pytz.timezone(time_zone))

    @staticmethod
    def convert_time_zone(
        dt: Union[datetime, int, pd.Timestamp],
        from_tz: str = "UTC",
        to_tz: str = "US/Eastern",
    ) -> pd.Timestamp:
        """
        Convert datetime from one timezone to another.

        Args:
            dt : The datetime to convert. ``datetime`` and ``int`` inputs are
                passed through ``pd.to_datetime(dt, unit="s")``.
            from_tz : The timezone to convert from. Naive inputs are localized
                to it; aware inputs are first converted to it.
            to_tz : The timezone to convert to.

        Returns:
            dt_to : The converted datetime.
        """
        from_tz_pytz = pytz.timezone(from_tz)
        if isinstance(dt, (datetime, int)):
            dt_ts = pd.to_datetime(dt, unit="s")
        else:
            dt_ts = dt
        if dt_ts.tzinfo is None:
            dt_ts = dt_ts.tz_localize(from_tz_pytz)
        else:
            dt_ts = dt_ts.tz_convert(from_tz_pytz)
        return dt_ts.tz_convert(pytz.timezone(to_tz))

    @staticmethod
    def stop_time(time_zone: str, stop_time: str) -> bool:
        """
        Return True when the current time in ``time_zone`` is at or after
        ``stop_time`` (given as ``"HH:MM"``).
        """
        now = datetime.now(pytz.timezone(time_zone)).time()
        stop_time_dt = datetime.strptime(stop_time, "%H:%M").time()
        return now >= stop_time_dt


class TWSStrategy(Strategy):
    """Strategy placeholder whose ``calculate_signals`` is not implemented."""

    def calculate_signals(self, *args: Any, **kwargs: Any) -> List[TradeSignal]:
        raise NotImplementedError("Should implement calculate_signals()")