Run with mlx_lm on macOS if you encounter an error with LM Studio
#5
by stiron9 - opened
Use the following Python script:
#!/usr/bin/env python3
import argparse
import json
from pathlib import Path
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler
from safetensors import safe_open
# Default local model directory; override on the command line with --model.
DEFAULT_MODEL = "/path/to/dealignai/Gemma-4-31B-JANG_4M-CRACK"
def _to_mlx_module_path(weight_key: str) -> str:
key = weight_key.removeprefix("model.")
if key.startswith("language_model."):
key = key.replace("language_model.", "language_model.model.", 1)
return key.removesuffix(".weight")
def _infer_quantization_config(model_dir: str) -> dict:
config_path = Path(model_dir) / "config.json"
with config_path.open("r", encoding="utf-8") as f:
cfg = json.load(f)
base_quant = dict(cfg.get("quantization", {}))
group_size = int(base_quant.get("group_size", 64))
default_bits = int(base_quant.get("bits", 4))
quantization = {"group_size": group_size, "bits": default_bits}
for shard in sorted(Path(model_dir).glob("model-*.safetensors")):
with safe_open(str(shard), framework="np") as sf:
keys = list(sf.keys())
keyset = set(keys)
for key in keys:
if not key.endswith(".weight"):
continue
scales_key = key.replace(".weight", ".scales")
if scales_key not in keyset:
continue
weight_shape = sf.get_slice(key).get_shape()
scales_shape = sf.get_slice(scales_key).get_shape()
if len(weight_shape) != 2 or len(scales_shape) != 2:
continue
packed_in = int(weight_shape[1])
n_groups = int(scales_shape[1])
if n_groups <= 0:
continue
in_features = n_groups * group_size
bits = (32 * packed_in) // in_features
if bits != default_bits:
quantization[_to_mlx_module_path(key)] = {
"group_size": group_size,
"bits": bits,
}
return quantization
def main() -> None:
    """CLI entry point: load the local model and print one completion."""
    arg_parser = argparse.ArgumentParser(description="Run local model with mlx-lm")
    arg_parser.add_argument("--model", default=DEFAULT_MODEL, help="Local model directory")
    arg_parser.add_argument("--prompt", default="hi who are u", help="Prompt text")
    arg_parser.add_argument("--max-tokens", type=int, default=256, help="Max new tokens")
    arg_parser.add_argument("--temp", type=float, default=0.7, help="Sampling temperature")
    opts = arg_parser.parse_args()

    # Supply the inferred per-module quantization so mixed-bit checkpoints load.
    overrides = {"quantization": _infer_quantization_config(opts.model)}
    model, tokenizer = load(opts.model, model_config=overrides)

    sampler = make_sampler(temp=opts.temp)
    text = generate(
        model,
        tokenizer,
        prompt=opts.prompt,
        max_tokens=opts.max_tokens,
        sampler=sampler,
    )
    print(text)
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()