Run with mlx_lm on macOS if you encounter an error with LM Studio

#5
by stiron9 - opened

Use the following Python script:

#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler
from safetensors import safe_open

# Default local checkpoint directory; override at the CLI with --model.
DEFAULT_MODEL = "/path/to/dealignai/Gemma-4-31B-JANG_4M-CRACK"


def _to_mlx_module_path(weight_key: str) -> str:
    key = weight_key.removeprefix("model.")
    if key.startswith("language_model."):
        key = key.replace("language_model.", "language_model.model.", 1)
    return key.removesuffix(".weight")


def _infer_quantization_config(model_dir: str) -> dict:
    config_path = Path(model_dir) / "config.json"
    with config_path.open("r", encoding="utf-8") as f:
        cfg = json.load(f)

    base_quant = dict(cfg.get("quantization", {}))
    group_size = int(base_quant.get("group_size", 64))
    default_bits = int(base_quant.get("bits", 4))

    quantization = {"group_size": group_size, "bits": default_bits}

    for shard in sorted(Path(model_dir).glob("model-*.safetensors")):
        with safe_open(str(shard), framework="np") as sf:
            keys = list(sf.keys())
            keyset = set(keys)
            for key in keys:
                if not key.endswith(".weight"):
                    continue
                scales_key = key.replace(".weight", ".scales")
                if scales_key not in keyset:
                    continue

                weight_shape = sf.get_slice(key).get_shape()
                scales_shape = sf.get_slice(scales_key).get_shape()
                if len(weight_shape) != 2 or len(scales_shape) != 2:
                    continue

                packed_in = int(weight_shape[1])
                n_groups = int(scales_shape[1])
                if n_groups <= 0:
                    continue

                in_features = n_groups * group_size
                bits = (32 * packed_in) // in_features
                if bits != default_bits:
                    quantization[_to_mlx_module_path(key)] = {
                        "group_size": group_size,
                        "bits": bits,
                    }

    return quantization


def main() -> None:
    """CLI entry point: load a local mlx model and print one completion."""
    parser = argparse.ArgumentParser(description="Run local model with mlx-lm")
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Local model directory")
    parser.add_argument("--prompt", default="hi  who are u", help="Prompt text")
    parser.add_argument("--max-tokens", type=int, default=256, help="Max new tokens")
    parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    args = parser.parse_args()

    # Feed the inferred per-module quantization into load() so mixed-bit
    # checkpoints deserialize with the correct per-layer bit widths.
    quant_cfg = _infer_quantization_config(args.model)
    model, tokenizer = load(args.model, model_config={"quantization": quant_cfg})

    sampler = make_sampler(temp=args.temp)
    completion = generate(
        model,
        tokenizer,
        prompt=args.prompt,
        max_tokens=args.max_tokens,
        sampler=sampler,
    )
    print(completion)


if __name__ == "__main__":
    main()

Sign up or log in to comment