Skip to content

Commit

Permalink
msl wip
Browse files Browse the repository at this point in the history
  • Loading branch information
polymonster committed Sep 20, 2024
1 parent baffb84 commit 6171eff
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
5 changes: 5 additions & 0 deletions pmfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class BuildInfo:
metal_sdk = "" # macosx, iphoneos, appletvos
metal_min_os = "" # iOS (9.0 - 13.0), macOS (10.11 - 10.15)
metal_version = "2.0" # MSL version
discrete_binding = None # stage to pass to --msl-discrete-descriptor-set (spirv-cross)
debug = False # generate shader with debug info
inputs = [] # array of input files or directories
extensions = [] # array of shader extension currently for glsl/gles
Expand Down Expand Up @@ -167,6 +168,8 @@ def parse_args():
_info.cbuffer_offset = sys.argv[i + 1]
elif sys.argv[i] == "-stage_in":
_info.stage_in = sys.argv[i + 1]
elif sys.argv[i] == "-discrete_binding":
_info.discrete_binding = sys.argv[i + 1]
elif sys.argv[i] == "-v_flip":
_info.v_flip = True
elif sys.argv[i] == "-d":
Expand Down Expand Up @@ -247,6 +250,8 @@ def display_help():
print(" -source (optional) (generates platform source into -o no compilation)")
print(" -stage_in <0, 1> (optional) [metal only] (default 1) ")
print(" uses stage_in for metal vertex buffers, 0 uses raw buffers")
print(" -discrete_binding <int> (optional) [metal only] (default None) ")
print(" any resources bound on this space will be discretely bound. eg. static samplers, push constants")
print(" -cbuffer_offset (optional) [metal only] (default 4) ")
print(" specifies an offset applied to cbuffer locations to avoid collisions with vertex buffers")
print(" -texture_offset (optional) [vulkan only] (default 32) ")
Expand Down
36 changes: 29 additions & 7 deletions pmfx_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import zlib
import sys
import jsn
import re
import pprint

from multiprocessing.pool import ThreadPool

Expand Down Expand Up @@ -725,7 +727,7 @@ def to_spirv_msl_version(metal_version):


# cross compile hlsl -> spirv -> metal
def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath):
def cross_compile_hlsl_metal(info, shader_info, stage, entry_point, temp_filepath, output_filepath):
exe = os.path.join(info.tools_dir, "bin", "macos", "dxc")

metal_sdk, metal_min_os, metal_version = metal_compile_options(info)
Expand All @@ -743,12 +745,22 @@ def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, outpu

spirv_msl_version = to_spirv_msl_version(info.metal_version)
metal_filepath = os.path.splitext(temp_filepath)[0] + ".metal"
cmdline = "{} --msl --msl-version {} {} --output {}".format(spirv_cross, spirv_msl_version, spirv_filepath, metal_filepath)
spirv_cross_msl_args = "--msl-argument-buffers --msl-argument-buffer-tier 1"

if info.discrete_binding:
spirv_cross_msl_args += f" --msl-discrete-descriptor-set {info.discrete_binding}"

cmdline = "{} --msl --msl-version {} {} {} --output {}".format(spirv_cross, spirv_msl_version, spirv_cross_msl_args, spirv_filepath, metal_filepath)

ec, el, ol = build_pmfx.call_wait_subprocess(cmdline)
error_list += el
output_list += ol

'''
if "resources" in shader_info and "push_constants" in shader_info["resources"]:
patch_msl_push_constants(shader_info["resources"]["push_constants"], stage, entry_point, metal_filepath)
'''

air_filepath = os.path.splitext(temp_filepath)[0] + ".air"
cmdline = "xcrun -sdk {} metal {} {} -c -frecord-sources {} -o {}".format(metal_sdk, metal_min_os, metal_version, metal_filepath, air_filepath)
ec, el, ol = build_pmfx.call_wait_subprocess(cmdline)
Expand All @@ -764,16 +776,16 @@ def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, outpu


# compile a hlsl version 2
def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_filepath):
def compile_shader_hlsl(info, shader_info, stage, entry_point, temp_filepath, output_filepath):
exe = os.path.join(info.tools_dir, "bin", "dxc", "dxc")
open(temp_filepath, "w+").write(src)
open(temp_filepath, "w+").write(shader_info["src"])
# compile... or skip
error_code = 0
error_list = []
output_list = []
if info.compiled:
if info.shader_platform == "metal":
error_code, error_list, output_list = cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath)
error_code, error_list, output_list = cross_compile_hlsl_metal(info, shader_info, stage, entry_point, temp_filepath, output_filepath)
elif info.shader_platform == "hlsl":
cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, stage, info.shader_version, entry_point, output_filepath, temp_filepath)
cmdline += " " + build_pmfx.get_info().args
Expand Down Expand Up @@ -886,9 +898,19 @@ def generate_shader_info(pmfx, entry_point, stage, permute=None):
complete = False
break

# now add used resource src decls
# sort resources by space and register
sorted_resources = dict()
for category in resource_categories:
sorted_resources[category] = list()
for r in pmfx["resources"][category]:
sorted_resources[category].append(pmfx["resources"][category][r])
# sort the category
sorted_resources[category] = sorted(sorted_resources[category], key=lambda x: (x["register_space"], x["shader_register"]))

# now add used resource src decls
for category in resource_categories:
for r in sorted_resources[category]:
r = r["name"]
tokens = [r]
resource = pmfx["resources"][category][r]
# cbuffers with inline decl need to check for usage per member
Expand Down Expand Up @@ -1006,7 +1028,7 @@ def generate_shader_permutation(build_info, shader_info, stage, entry_point, pmf
output_filepath = os.path.join(output_path, filename + "c")
if shader_needs_compiling(pmfx, entry_point, shader_info["src_hash"], output_filepath):
temp_filepath = os.path.join(temp_path, filename)
shader_info["error_code"] = compile_shader_hlsl(build_info, shader_info["src"], stage, entry_point, temp_filepath, output_filepath)
shader_info["error_code"] = compile_shader_hlsl(build_info, shader_info, stage, entry_point, temp_filepath, output_filepath)
shader_info["filename"] = "{}/{}c".format(pmfx["pmfx_name"], filename)
return (stage, entry_point, shader_info)

Expand Down

0 comments on commit 6171eff

Please sign in to comment.