George-Polya opened a new issue, #17647:
URL: https://github.com/apache/tvm/issues/17647

   I tried to quantize llava-v1.6-34b:
   ```
   python3 -m mlc_llm.build --model /data/models/mlc/dist/models/llava-v1.6-34b \
        --quantization q4f16_ft \
        --target cuda \
        --use-cuda-graph \
        --use-flash-attn-mqa \
        --sep-embed \
        --max-seq-len 256 \
        --artifact-path /data/models/mlc/dist/llava-v1.6-34b/ctx256 \
        --use-safetensors
   ```
   
   ### Expected behavior
   
   `llava-v1.6-34b-q4f16_ft-cuda.so` is created in
   `/data/models/mlc/dist/llava-v1.6-34b/ctx256/llava-v1.6-34b-q4f16_ft/`
   (i.e. `/data/models/mlc/dist/llava-v1.6-34b/ctx256/llava-v1.6-34b-q4f16_ft/llava-v1.6-34b-q4f16_ft-cuda.so`).
   
   ### Actual behavior
   
   ```
   Traceback (most recent call last):
     File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
       return _run_code(code, main_globals, None,
     File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
       exec(code, run_globals)
     File "/usr/local/lib/python3.10/dist-packages/mlc_llm/build.py", line 47, in <module>
       main()
     File "/usr/local/lib/python3.10/dist-packages/mlc_llm/build.py", line 43, in main
       core.build_model_from_args(parsed_args)
     File "/usr/local/lib/python3.10/dist-packages/mlc_llm/core.py", line 961, in build_model_from_args
       mod = mod_transform_before_build(mod, param_manager, args, model_config)
     File "/usr/local/lib/python3.10/dist-packages/mlc_llm/core.py", line 613, in mod_transform_before_build
       mod = fuse_split_rotary_embedding(
     File "/usr/local/lib/python3.10/dist-packages/tvm/ir/transform.py", line 238, in __call__
       return _ffi_transform_api.RunPass(self, mod)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
     File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
     File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
     File "/usr/local/lib/python3.10/dist-packages/mlc_llm/transform/fuse_split_rotary_embedding.py", line 118, in ir_module_pass
       split_rotary = get_dynamic_split_rotary()
     File "/usr/local/lib/python3.10/dist-packages/mlc_llm/transform/fuse_split_rotary_embedding.py", line 100, in get_dynamic_split_rotary
       relax.expr._update_struct_info(
     File "/usr/local/lib/python3.10/dist-packages/tvm/relax/expr.py", line 1224, in _update_struct_info
       _ffi_api.UpdateStructInfo(expr, struct_info)  # type: ignore
     File "tvm/_ffi/_cython/./packed_func.pxi", line 339, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 270, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 259, in tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 185, in tvm._ffi._cy3.core.CHECK_CALL
     File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
       raise py_err
   tvm.error.InternalError: Traceback (most recent call last):
     [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(TVMFuncCall+0x68) [0xffff73860f98]
     [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1944cf4) [0xffff720a4cf4]
     [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::relax::UpdateStructInfo(tvm::RelayExpr, tvm::relax::StructInfo)+0x1bc) [0xffff7209fd10]
     [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x193f8ec) [0xffff7209f8ec]
     [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff7199c7f8]
     [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff738abfc0]
     File "/opt/mlc-llm/3rdparty/tvm/src/relax/ir/struct_info.cc", line 211
   InternalError: Check failed: (!expr->struct_info_.defined()) is false: To ensure idempotency, the expression passed to UpdateStructInfo must not have any prior StructInfo.  However, expression # from tvm.script import tir as T
   @T.prim_func(private=True)
   def main(fused_qkv_handle: T.handle, embedded_query_handle: T.handle, embedded_key_handle: T.handle, value_handle: T.handle, rotary_offset: T.int64, batch_size: T.int64, seq_len: T.int64, num_query_heads: T.int64, num_kv_heads: T.int64, head_dim: T.int64, position_embedding_base: T.float32):
       T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)})
       Fused_QKV = T.match_buffer(fused_qkv_handle, (batch_size, seq_len, num_query_heads + num_kv_heads * T.int64(2), head_dim), "float16")
       EmbeddedQuery = T.match_buffer(embedded_query_handle, (batch_size, seq_len, num_query_heads, head_dim), "float16")
       EmbeddedKey = T.match_buffer(embedded_key_handle, (batch_size, seq_len, num_kv_heads, head_dim), "float16")
       Value = T.match_buffer(value_handle, (batch_size, seq_len, num_kv_heads, head_dim), "float16")
       # with T.block("root"):
       for iters_0, iters_1, iters_2, iters_3 in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * T.int64(2), head_dim):
           with T.block("FusedRotaryEmbeddingAndSplitQKV"):
               batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", [iters_0, iters_1, iters_2, iters_3])
               T.reads(Fused_QKV[batch_i, seq_i, head_num, T.min(T.min(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)):T.min(T.min(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)) + (T.max(T.max(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)) + T.int64(1) - T.min(T.min(head_i, head_dim // T.int64(2) + head_i), head_i - head_dim // T.int64(2)))])
               T.writes(EmbeddedQuery[batch_i, seq_i, head_num, head_i], EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i], Value[batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i])
               pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len)
               inv_freq: T.float32 = T.float32(1.0) / T.pow(position_embedding_base, T.Cast("float32", head_i * T.int64(2) % head_dim) / T.Cast("float32", head_dim))
               freq: T.float32 = pos * inv_freq
               cos_value: T.float16 = T.Cast("float16", T.cos(freq))
               sin_value: T.float16 = T.Cast("float16", T.sin(freq))
               input_value: T.float16 = Fused_QKV[batch_i, seq_i, head_num, head_i]
               embedded_value: T.float16 = cos_value * input_value + sin_value * T.Select(head_i < head_dim // T.int64(2), Fused_QKV[batch_i, seq_i, head_num, head_i + head_dim // T.int64(2)] * T.float16(-1.0), Fused_QKV[batch_i, seq_i, head_num, head_i - head_dim // T.int64(2)])
               if head_num < num_query_heads:
                   EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value
               else:
                   if head_num < num_query_heads + num_kv_heads:
                       EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value
                   else:
                       Value[batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i] = input_value has struct info R.Callable((R.Tensor((batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("float32")), R.Tuple, False), which cannot be overwritten with R.Callable((R.Tensor((batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("int64"), R.Prim("float32")), R.Tuple, False)
   ```
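
   For context, the check that fires here can be reproduced in isolation. Below is a minimal, illustrative sketch (not taken from the build above; the variable names are mine) of the idempotency rule that `UpdateStructInfo` enforces. The traceback suggests the function built inside `get_dynamic_split_rotary` already carries StructInfo in this TVM version, so the pass's unconditional `relax.expr._update_struct_info` call hits the same check even though the new StructInfo is identical.

   ```
   import tvm
   from tvm import relax

   # A relax.Var constructed with explicit struct info already carries it.
   x = relax.Var("x", relax.TensorStructInfo((4,), "float32"))

   # Setting struct info again -- even to the identical value -- trips the
   # same check: "To ensure idempotency, the expression passed to
   # UpdateStructInfo must not have any prior StructInfo."
   relax.expr._update_struct_info(x, x.struct_info)  # raises tvm.error.InternalError
   ```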
   
   ### Environment
   
   
   ```
   Platform: Jetson Orin AGX, JetPack 6.2
   Docker image: dustynv/nano_llm:r36.4.0
   Packages: tvm==0.19.0, mlc-llm==0.1.0
   ```
   
   ### Steps to reproduce
   
   In the `nano_llm:r36.4.0` container, run:
   ```
     python3 -m nano_llm.vision.video \
       --api=mlc \
       --model liuhaotian/llava-v1.6-34b \
       --max-images 8 \
       --max-context-len 256 \
       --max-new-tokens 48 \
       --video-input <video input> \
       --video-output <video output> \
       --prompt <prompt>
   ```
   
   

