Allow passing --device="mps" by adding it to the argument choices: choices=["cuda", "cpu", "mps"]
Set the model kwargs: kwargs = {"torch_dtype": torch.float16}
Add .to("mps") on line 98: model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, **kwargs).to("mps")
Comment out: raise ValueError(f"Invalid device: {args.device}")
Change cuda to mps on line 80: if args.device == "mps":
I'm not sure it's working correctly, but at least it's a step. It told me how to catch a duck, but it often falls into some "renewable energy" sequence. :D
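A slightly tidier variant of the same steps, as a rough sketch: instead of replacing the cuda check and commenting out the ValueError, give "mps" its own entry so cuda and cpu keep working. This assumes the script's device handling looks roughly like the lines quoted above; DTYPE_BY_DEVICE, dtype_name_for and load_model are hypothetical names, and float16 on MPS is simply carried over from the steps above, not a verified recommendation.

```python
import argparse

# Hypothetical table (assumption): float16 on the GPU backends, default dtype on CPU.
DTYPE_BY_DEVICE = {"cuda": "float16", "mps": "float16", "cpu": None}


def parse_args(argv=None):
    parser = argparse.ArgumentParser()
    # "mps" added next to the existing choices; the "cuda" default is assumed.
    parser.add_argument("--device", choices=["cuda", "cpu", "mps"], default="cuda")
    return parser.parse_args(argv)


def dtype_name_for(device):
    """Return the torch dtype name for a device, or None to keep the default."""
    if device not in DTYPE_BY_DEVICE:
        # Each device has an explicit entry, so this check can stay in place.
        raise ValueError(f"Invalid device: {device}")
    return DTYPE_BY_DEVICE[device]


def load_model(model_name, device):
    # Imported lazily so the argument handling can be tried without torch installed.
    import torch
    from transformers import AutoModelForCausalLM

    if device == "mps" and not torch.backends.mps.is_available():
        raise RuntimeError("MPS backend is not available in this PyTorch build")

    kwargs = {}
    name = dtype_name_for(device)
    if name is not None:
        kwargs["torch_dtype"] = getattr(torch, name)

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **kwargs
    )
    # Same idea as the .to("mps") change, but driven by the --device flag.
    return model.to(device)
```

Usage would be along the lines of args = parse_args(["--device", "mps"]) followed by load_model(model_name, args.device), with model_name being whatever the script already passes. Keeping the dtype in one table also makes it easy to compare float16 against the default float32 on MPS when the output looks off.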