From c7eefb1166738c4d9a65c631ab98a52e3a06b5a6 Mon Sep 17 00:00:00 2001 From: Hokori Date: Mon, 15 Jun 2026 00:30:55 +0800 Subject: [PATCH] =?UTF-8?q?fuck=20=E6=A0=B9=E6=9C=AC=E6=B2=A1=E6=94=B9?= =?UTF-8?q?=E7=AE=97=E5=8A=9B=E5=88=A4=E6=96=AD=EF=BC=8C=E6=88=91=E7=9C=9F?= =?UTF-8?q?=E6=98=AF=E6=93=8D=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flash_qla/ops/gated_delta_rule/chunk/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_qla/ops/gated_delta_rule/chunk/__init__.py b/flash_qla/ops/gated_delta_rule/chunk/__init__.py index 1db0116..a2d8ee0 100644 --- a/flash_qla/ops/gated_delta_rule/chunk/__init__.py +++ b/flash_qla/ops/gated_delta_rule/chunk/__init__.py @@ -7,10 +7,10 @@ import tilelang from flash_qla.utils import l2norm from flash_qla.ops.utils import chunk_local_cumsum, group_reduce_vector -if tilelang.contrib.nvcc.get_target_compute_version() == "9.0": +if tilelang.contrib.nvcc.get_target_compute_version() == "7.5": from .hopper import fused_gdr_fwd, fused_gdr_bwd, fused_gdr_h, kkt_solve else: - raise ValueError("FlashQLA now support sm90 only.") + raise ValueError("FlashQLA now support sm75 only.") from .cp_context import intra_card_cp_preprocess