@jeffrey4l
Last active March 27, 2025 02:06
Patch for ktransformers to support NVIDIA V100 and T4 GPUs; tested against commit 7a19f3b78116391b0c809285997e7b8df1ba5c0d.
diff --git a/install.sh b/install.sh
index ffb7aca..c3730fd 100644
--- a/install.sh
+++ b/install.sh
@@ -11,5 +11,5 @@ echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
-KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation
-echo "Installation completed successfully"
\ No newline at end of file
+CMAKE_ARGS="-DLLAMA_NATIVE=off" KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation
+echo "Installation completed successfully"
diff --git a/ktransformers/ktransformers_ext/cuda/setup.py b/ktransformers/ktransformers_ext/cuda/setup.py
index 156bb0e..1f13f95 100644
--- a/ktransformers/ktransformers_ext/cuda/setup.py
+++ b/ktransformers/ktransformers_ext/cuda/setup.py
@@ -13,14 +13,14 @@ setup(
# 'gptq_marlin_repack.cu',
],
extra_compile_args={
- 'cxx': ['-O3'],
+ 'cxx': ['-O3', '-D_GLIBCXX_USE_CXX11_ABI=1'],
'nvcc': [
'-O3',
'--use_fast_math',
- '-Xcompiler', '-fPIC',
+ '-Xcompiler', '-fPIC', '-D_GLIBCXX_USE_CXX11_ABI=1'
]
},
)
],
cmdclass={'build_ext': BuildExtension}
-)
\ No newline at end of file
+)
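
Note on the setup.py hunk: the -D_GLIBCXX_USE_CXX11_ABI value has to match the ABI your installed PyTorch wheel was built with, or the extension will fail to import. A minimal check, not part of the patch (verify before hard-coding =1, since official wheels have shipped with either ABI depending on version):

```python
# Check which C++ ABI the installed torch build uses; the value passed as
# -D_GLIBCXX_USE_CXX11_ABI to the extension build should match it.
import torch

print(torch.compiled_with_cxx11_abi())
# True  -> compile the extension with -D_GLIBCXX_USE_CXX11_ABI=1
# False -> compile the extension with -D_GLIBCXX_USE_CXX11_ABI=0
```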
From 1d3f2ede5adebbd3a6fca0afa083545a68112574 Mon Sep 17 00:00:00 2001
From: Your Name <[email protected]>
Date: Thu, 27 Feb 2025 23:35:12 +0800
Subject: [PATCH] support v100
---
Dockerfile | 24 +++++++--------
ktransformers/local_chat.py | 10 +++----
ktransformers/operators/attention.py | 29 ++++++++++++++-----
.../DeepSeek-V2-Chat-multi-gpu-4.yaml | 10 +++----
.../DeepSeek-V2-Chat-multi-gpu.yaml | 6 ++--
.../optimize_rules/DeepSeek-V2-Chat.yaml | 4 +--
.../DeepSeek-V2-Lite-Chat-multi-gpu.yaml | 6 ++--
.../optimize_rules/DeepSeek-V2-Lite-Chat.yaml | 6 ++--
.../DeepSeek-V3-Chat-multi-gpu-4.yaml | 10 +++----
.../DeepSeek-V3-Chat-multi-gpu-8.yaml | 18 ++++++------
.../DeepSeek-V3-Chat-multi-gpu-marlin.yaml | 6 ++--
.../DeepSeek-V3-Chat-multi-gpu.yaml | 6 ++--
.../optimize_rules/DeepSeek-V3-Chat.yaml | 4 +--
.../optimize/optimize_rules/Mixtral.yaml | 4 +--
.../optimize_rules/Moonlight-16B-A3B.yaml | 4 +--
.../Qwen2-57B-A14B-Instruct-multi-gpu.yaml | 6 ++--
.../Qwen2-57B-A14B-Instruct.yaml | 4 +--
17 files changed, 85 insertions(+), 72 deletions(-)
diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py
index 7cbac7c..e4f5660 100644
--- a/ktransformers/local_chat.py
+++ b/ktransformers/local_chat.py
@@ -81,17 +81,17 @@ def local_chat(
print("using custom modeling_xxx.py.")
if (
"Qwen2Moe" in config.architectures[0]
- ): # Qwen2Moe must use flash_attention_2 to avoid overflow.
- config._attn_implementation = "flash_attention_2"
+ ): # Qwen2Moe must use eager to avoid overflow.
+ config._attn_implementation = "eager"
if "Llama" in config.architectures[0]:
config._attn_implementation = "eager"
if "Mixtral" in config.architectures[0]:
- config._attn_implementation = "flash_attention_2"
+ config._attn_implementation = "eager"
model = custom_models[config.architectures[0]](config)
else:
model = AutoModelForCausalLM.from_config(
- config, trust_remote_code=True, attn_implementation="flash_attention_2"
+ config, trust_remote_code=True, attn_implementation="eager"
)
if optimize_config_path is None:
@@ -180,4 +180,4 @@ def local_chat(
if __name__ == "__main__":
- fire.Fire(local_chat)
\ No newline at end of file
+ fire.Fire(local_chat)
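
Background for the local_chat.py hunk above: flash-attention 2 kernels require compute capability 8.0 (Ampere) or newer, which is why the patch falls back to the eager implementation on V100 (sm70) and T4 (sm75). A minimal sketch, not part of the patch, of deriving the setting from the device instead of hard-coding it:

```python
# Illustrative only (not part of the patch): derive the HF attention
# implementation from the GPU's compute capability instead of hard-coding it.
import torch

def pick_attn_implementation() -> str:
    major, _minor = torch.cuda.get_device_capability()
    # flash_attention_2 needs compute capability >= 8.0 (Ampere and newer);
    # V100 is sm70 and T4 is sm75, so they fall back to eager attention.
    return "flash_attention_2" if major >= 8 else "eager"

# e.g. config._attn_implementation = pick_attn_implementation()
```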
diff --git a/ktransformers/operators/attention.py b/ktransformers/operators/attention.py
index 35c8093..0b84350 100644
--- a/ktransformers/operators/attention.py
+++ b/ktransformers/operators/attention.py
@@ -272,6 +272,13 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
print("position_ids", torch.isnan(position_ids).any())
"""
+ original_dtype = query_states.dtype
+ target_dtype = torch.half
+ query_states = query_states.to(target_dtype)
+ compressed_kv_with_k_pe = compressed_kv_with_k_pe.to(target_dtype)
+ compressed_kv = compressed_kv.to(target_dtype)
+ attn_output = attn_output.to(target_dtype)
+
# flash attn doesn't support head_dim bigger than 256
# use triton attention kernel adapted from vLLM and SGLang for MQA
decode_attention_fwd_grouped(query_states, compressed_kv_with_k_pe, compressed_kv, attn_output,
@@ -280,6 +287,7 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
4, #num_kv_splits # follow vLLM, fix it TODO
self.softmax_scale,
past_key_value.page_size)
+ attn_output = attn_output.to(original_dtype)
# attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
# out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
@@ -321,13 +329,20 @@ class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)
- attn_output = flash_attn_func(
- query_states,
- key_states,
- value_states_padded,
- softmax_scale=self.softmax_scale,
- causal=True,
- )
+ # attn_output = flash_attn_func(
+ # query_states,
+ # key_states,
+ # value_states_padded,
+ # softmax_scale=self.softmax_scale,
+ # causal=True,
+ # )
+ attn_output = F.scaled_dot_product_attention(
+ query_states.transpose(1, 2),
+ key_states.transpose(1, 2),
+ value_states_padded.transpose(1, 2),
+ scale=self.softmax_scale,
+ is_causal=True
+ ).transpose(1, 2)
if self.q_head_dim != self.v_head_dim:
attn_output = attn_output[:, :, :, : self.v_head_dim]
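
The attention.py hunks above make two substitutions: the decode path casts its tensors to fp16 before the grouped-MQA Triton kernel and back afterwards, and the prefill path swaps flash_attn_func for torch.nn.functional.scaled_dot_product_attention, which runs on pre-Ampere GPUs. A standalone sketch of the SDPA call with made-up shapes (assumes PyTorch 2.1+ for the scale argument):

```python
# Illustrative stand-in for the removed flash_attn_func call, with made-up
# shapes. flash-attn uses a (batch, seq, heads, head_dim) layout while SDPA
# expects (batch, heads, seq, head_dim), hence the transposes.
import torch
import torch.nn.functional as F

bsz, seq_len, num_heads, head_dim = 1, 128, 16, 192
softmax_scale = head_dim ** -0.5

# float32 here so the sketch runs anywhere; on V100/T4 the patch keeps the
# real tensors in fp16 because those GPUs have no usable bf16 path.
q = torch.randn(bsz, seq_len, num_heads, head_dim)
k = torch.randn(bsz, seq_len, num_heads, head_dim)
v = torch.randn(bsz, seq_len, num_heads, head_dim)

attn_output = F.scaled_dot_product_attention(
    q.transpose(1, 2),
    k.transpose(1, 2),
    v.transpose(1, 2),
    scale=softmax_scale,
    is_causal=True,
).transpose(1, 2)  # back to (batch, seq, heads, head_dim)
```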
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml
index 66a420a..173a6e0 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu-4.yaml
@@ -47,7 +47,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([2][0-9]|[1][5-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
@@ -57,7 +57,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([3][0-9]|[4][0-4])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
@@ -67,7 +67,7 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([5][0-9]|[4][5-9])\\.(?!self_attn\\.kv_b_proj).*$" # regular expression
@@ -77,7 +77,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -228,7 +228,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml
index f409376..63b3ffa 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat-multi-gpu.yaml
@@ -31,7 +31,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -42,7 +42,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -125,7 +125,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml
index 7f3e44e..85a3aeb 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Chat.yaml
@@ -13,7 +13,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -24,7 +24,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml
index 158892d..bb7891f 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat-multi-gpu.yaml
@@ -31,7 +31,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -42,7 +42,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -125,7 +125,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
index 7f3e44e..d2c92d0 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V2-Lite-Chat.yaml
@@ -13,7 +13,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -24,7 +24,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -65,4 +65,4 @@
class: "default"
kwargs:
generate_device: "cpu"
- prefill_device: "cpu"
\ No newline at end of file
+ prefill_device: "cpu"
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
index ea75b30..25e6d05 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
@@ -59,7 +59,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 1: layers 15–29
@@ -71,7 +71,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 2: layers 30–44
@@ -83,7 +83,7 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 3: layers 45–60
@@ -95,7 +95,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# === MLP (MoE) Replacement ===
@@ -375,7 +375,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# For final modules (model.norm), ensure they are on GPU 3 (as in your original config)
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml
index b00d2b4..e746680 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-8.yaml
@@ -100,7 +100,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 1: layers 8–15
@@ -112,7 +112,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 2: layers 16–23
@@ -124,7 +124,7 @@
kwargs:
generate_device: "cuda:2"
prefill_device: "cuda:2"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 3: layers 24–31
@@ -136,7 +136,7 @@
kwargs:
generate_device: "cuda:3"
prefill_device: "cuda:3"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 4: layers 32–39
@@ -148,7 +148,7 @@
kwargs:
generate_device: "cuda:4"
prefill_device: "cuda:4"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 5: layers 40–47
@@ -160,7 +160,7 @@
kwargs:
generate_device: "cuda:5"
prefill_device: "cuda:5"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 6: layers 48–55
@@ -172,7 +172,7 @@
kwargs:
generate_device: "cuda:6"
prefill_device: "cuda:6"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# GPU 7: layers 56–63
@@ -184,7 +184,7 @@
kwargs:
generate_device: "cuda:7"
prefill_device: "cuda:7"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
@@ -721,7 +721,7 @@
kwargs:
generate_device: "cuda:7"
prefill_device: "cuda:7"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
# For final modules (model.norm), ensure they are on GPU 7 (as in your original config)
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
index e04c6ce..0fca38c 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
@@ -31,7 +31,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -42,7 +42,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -160,7 +160,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
index 50e282d..88174ea 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu.yaml
@@ -31,7 +31,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -42,7 +42,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -142,7 +142,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
index d28e016..f0f8718 100644
--- a/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
+++ b/ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat.yaml
@@ -14,7 +14,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -25,7 +25,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
diff --git a/ktransformers/optimize/optimize_rules/Mixtral.yaml b/ktransformers/optimize/optimize_rules/Mixtral.yaml
index 80a346a..a8705ac 100644
--- a/ktransformers/optimize/optimize_rules/Mixtral.yaml
+++ b/ktransformers/optimize/optimize_rules/Mixtral.yaml
@@ -13,7 +13,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
@@ -23,7 +23,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.block_sparse_moe$"
diff --git a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
index 6cea246..dc0fd6a 100644
--- a/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
+++ b/ktransformers/optimize/optimize_rules/Moonlight-16B-A3B.yaml
@@ -14,7 +14,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
@@ -25,7 +25,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml
index da01c82..caba1e1 100644
--- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml
+++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct-multi-gpu.yaml
@@ -14,7 +14,7 @@
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([012])\\.mlp$"
@@ -50,7 +50,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\.([12][0-9]|[3-9])\\.mlp$"
@@ -85,7 +85,7 @@
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
diff --git a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml
index 38e9e73..b12f022 100644
--- a/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml
+++ b/ktransformers/optimize/optimize_rules/Qwen2-57B-A14B-Instruct.yaml
@@ -13,7 +13,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^lm_head"
@@ -23,7 +23,7 @@
kwargs:
generate_device: "cuda"
prefill_device: "cuda"
- generate_op: "KLinearMarlin"
+ generate_op: "KLinearTorch"
prefill_op: "KLinearTorch"
- match:
name: "^model\\.layers\\..*\\.mlp$"
--
2.34.1
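
All of the optimize_rules changes above replace the KLinearMarlin generate op with KLinearTorch; the Marlin FP16xINT4 kernels target compute capability 8.0 and newer, so they are presumably unusable on V100/T4, while the plain torch linear path is slower but portable. A hedged sketch of choosing the op from the device at runtime rather than editing every rule file:

```python
# Illustrative only: pick the linear generate op from the GPU generation
# instead of hard-coding it in each optimize_rules YAML file.
import torch

def pick_linear_generate_op() -> str:
    major, minor = torch.cuda.get_device_capability()
    # Marlin kernels target compute capability >= 8.0 (Ampere and newer).
    return "KLinearMarlin" if (major, minor) >= (8, 0) else "KLinearTorch"

print(pick_linear_generate_op())  # "KLinearTorch" on V100 (7.0) / T4 (7.5)
```
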
@jeffrey4l (Author)

The patch comes from issue kvcache-ai/ktransformers#425.

@staskikotx

Thank you for the great work. Could you add some details on the models you tested this fix with? I've just tried it on DeepSeek-V2-Lite-Chat.Q4_K_M.gguf with a Tesla V100, but it gives NaNs as the output logits. Should the patched version work with this model, or does it only work on unquantized models, so that I need to write some additional code to dequantize this one?

@jeffrey4l (Author)

@staskikotx What's the exact error you hit? This should work for V2/V3/R1 models.

@staskikotx

Here is the error.

Chat: Who are you?
../aten/src/ATen/native/cuda/TensorCompare.cu:110: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.
Traceback (most recent call last):
  File "/workspace/ktransformers_test.py", line 9, in <module>
    local_chat.local_chat(model_path=model_path, gguf_path=gguf_path, chunk_prefill_size=chunk_prefill_size)
  File "/opt/conda/lib/python3.11/site-packages/ktransformers/local_chat.py", line 181, in local_chat
    generated = prefill_and_generate(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/ktransformers/util/utils.py", line 214, in prefill_and_generate
    next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: device-side assert triggered

I checked, and there are indeed NaNs in the logits and probs variables.

I could have messed up something while applying the patch. Or should I add a couple of lines to explicitly dequantize the model I am using?

@jeffrey4l (Author)

@staskikotx I pushed my code to https://github.com/jeffrey4l/ktransformers/tree/support_t4, could you give it a try?
