diff --git a/flash_qla/ops/gated_delta_rule/chunk/__init__.py b/flash_qla/ops/gated_delta_rule/chunk/__init__.py index 1db0116..a2d8ee0 100644 --- a/flash_qla/ops/gated_delta_rule/chunk/__init__.py +++ b/flash_qla/ops/gated_delta_rule/chunk/__init__.py @@ -7,10 +7,10 @@ import tilelang from flash_qla.utils import l2norm from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector -if tilelang.contrib.nvcc.get_target_compute_version() == "9.0": +if tilelang.contrib.nvcc.get_target_compute_version() == "7.5": from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve else: - raise ValueError("FlashQLA now support sm90 only.") + raise ValueError("FlashQLA now support sm75 only.") from .cp_context import intra_card_cp_preprocess