From 102ecf91bf8f0259f1922bb9b1a8c3ec23befd21 Mon Sep 17 00:00:00 2001 From: Travis Cline Date: Wed, 6 Sep 2023 20:37:06 -0700 Subject: [PATCH] codegen: Handle a few more cases, populate some mps structs --- generate/codegen/gen_function.go | 2 ++ macos/mps/functions.gen.go | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/generate/codegen/gen_function.go b/generate/codegen/gen_function.go index 6ecfb603..2852a63c 100644 --- a/generate/codegen/gen_function.go +++ b/generate/codegen/gen_function.go @@ -177,6 +177,8 @@ func (f *Function) WriteGoCallCode(currentModule *modules.Module, cw *CodeWriter sb.WriteString(cw.IndentStr + fmt.Sprintf(" (*C.%s)(unsafe.Pointer(&%s))", tt.CName(), p.GoName())) case *typing.IDType: sb.WriteString(cw.IndentStr + fmt.Sprintf(" %s.Ptr()", p.GoName())) + case *typing.ClassType, *typing.ProtocolType: + sb.WriteString(cw.IndentStr + fmt.Sprintf(" unsafe.Pointer(&%s)", p.GoName())) default: sb.WriteString(cw.IndentStr + p.GoName()) } diff --git a/macos/mps/functions.gen.go b/macos/mps/functions.gen.go index c4d20482..b8ac9e71 100644 --- a/macos/mps/functions.gen.go +++ b/macos/mps/functions.gen.go @@ -49,7 +49,7 @@ func StateBatchResourceSize(batch *foundation.Array) uint { func HintTemporaryMemoryHighWaterMark(cmdBuf metal.CommandBufferWrapper, bytes uint) { C.HintTemporaryMemoryHighWaterMark( // *typing.ProtocolType - cmdBuf, + unsafe.Pointer(&cmdBuf), // *typing.PrimitiveType C.uint(bytes), ) @@ -73,7 +73,7 @@ func ImageBatchResourceSize(batch *foundation.Array) uint { func SetHeapCacheDuration(cmdBuf metal.CommandBufferWrapper, seconds float64) { C.SetHeapCacheDuration( // *typing.ProtocolType - cmdBuf, + unsafe.Pointer(&cmdBuf), // *typing.PrimitiveType C.double(seconds), ) @@ -101,7 +101,7 @@ func StateBatchSynchronize(batch *foundation.Array, cmdBuf metal.CommandBufferWr // *typing.PointerType (*C.MPSStateBatch)(unsafe.Pointer(&batch)), // *typing.ProtocolType - cmdBuf, + unsafe.Pointer(&cmdBuf), ) } @@ -180,10 +180,10 @@ func GetCustomKernelBroadcastSourceIndex(c CustomKernelArgumentCount, sourceInde func GetImageType(image Image) ImageType { rv := C.GetImageType( // *typing.ClassType - image, + unsafe.Pointer(&image), ) // *typing.AliasType - return ImageType(rv) + return *(*ImageType)(unsafe.Pointer(&rv)) } // Returns the integer division parameters for a specified divisor. [Full Topic] @@ -222,7 +222,7 @@ func ImageBatchSynchronize(batch *foundation.Array, cmdBuf metal.CommandBufferWr // *typing.PointerType (*C.MPSImageBatch)(unsafe.Pointer(&batch)), // *typing.ProtocolType - cmdBuf, + unsafe.Pointer(&cmdBuf), ) } @@ -232,7 +232,7 @@ func ImageBatchSynchronize(batch *foundation.Array, cmdBuf metal.CommandBufferWr func SupportsMTLDevice(device metal.DeviceWrapper) bool { rv := C.SupportsMTLDevice( // *typing.ProtocolType - device, + unsafe.Pointer(&device), ) // *typing.PrimitiveType return bool(rv)