Skip to content

Commit

Permalink
Enable scala_macro_library targets to have dependencies (#1681)
Browse files Browse the repository at this point in the history
* Enabled ijar for scala_macro_library targets

* Export a ScalaInfo provider

* Include macros' transitive runtime dependencies on the compile classpath

---------

Co-authored-by: Jaden Peterson <jpeterson@lucidchart.com>
  • Loading branch information
jadenPete and Jaden Peterson authored Feb 5, 2025
1 parent bfb9b9e commit 3ca60fb
Show file tree
Hide file tree
Showing 16 changed files with 182 additions and 56 deletions.
29 changes: 27 additions & 2 deletions scala/private/common.bzl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
load("@io_bazel_rules_scala//scala:jars_to_labels.bzl", "JarsToLabelsInfo")
load("@io_bazel_rules_scala//scala:plusone.bzl", "PlusOneDeps")
load("@io_bazel_rules_scala//scala:providers.bzl", "ScalaInfo")
load("@bazel_skylib//lib:paths.bzl", "paths")

def write_manifest_file(actions, output_file, main_class):
Expand All @@ -22,6 +23,7 @@ def collect_jars(
compile_jars = []
runtime_jars = []
deps_providers = []
macro_classpath = []

for dep_target in dep_targets:
# we require a JavaInfo for dependencies
Expand Down Expand Up @@ -50,11 +52,34 @@ def collect_jars(
java_provider.compile_jars.to_list(),
)

# Macros are different from ordinary targets in that they’re used at compile time instead of at runtime. That
# means that both their compile-time classpath and runtime classpath are needed at compile time. We could have
# `scala_macro_library` targets include their runtime dependencies in their compile-time dependencies, but then
# we wouldn't have any guarantees classpath order.
#
# Consider the following scenario. Target A depends on targets B and C. Target C is a macro target, whereas
# target B isn't. Targets C depends on target B. If target A doesn't include the runtime version of target C on
# the compile classpath before the compile (`ijar`d) version of target B that target C depends on, then target A
# won't use the correct version of target B at compile-time when evaluating the macros contained in target C.
#
# For that reason, we opt for a different approach: have `scala_macro_library` targets export `JavaInfo`
# providers as normal, but put their transitive runtime dependencies first on the classpath. Note that we
# shouldn't encounter any issues with external dependencies, so long as they aren't `ijar`d.
if ScalaInfo in dep_target and dep_target[ScalaInfo].contains_macros:
macro_classpath.append(java_provider.transitive_runtime_jars)

add_labels_of_jars_to(
jars2labels,
dep_target,
[],
java_provider.transitive_runtime_jars.to_list(),
)

return struct(
compile_jars = depset(transitive = compile_jars),
compile_jars = depset(order = "preorder", transitive = macro_classpath + compile_jars),
transitive_runtime_jars = depset(transitive = runtime_jars),
jars2labels = JarsToLabelsInfo(jars_to_labels = jars2labels),
transitive_compile_jars = depset(transitive = transitive_compile_jars),
transitive_compile_jars = depset(order = "preorder", transitive = macro_classpath + transitive_compile_jars),
deps_providers = deps_providers,
)

Expand Down
11 changes: 0 additions & 11 deletions scala/private/phases/phase_compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ def phase_compile_library_for_plugin_bootstrapping(ctx, p):
)
return _phase_compile_default(ctx, p, args)

def phase_compile_macro_library(ctx, p):
args = struct(
buildijar = False,
unused_dependency_checker_ignored_targets = [
target.label
for target in p.scalac_provider.default_macro_classpath + ctx.attr.exports +
ctx.attr.unused_dependency_checker_ignored_targets
],
)
return _phase_compile_default(ctx, p, args)

def phase_compile_junit_test(ctx, p):
args = struct(
buildijar = False,
Expand Down
14 changes: 14 additions & 0 deletions scala/private/phases/phase_scalainfo_provider.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
load("//scala:providers.bzl", "ScalaInfo")

def _phase_scalainfo_provider_implementation(contains_macros):
return struct(
external_providers = {
"ScalaInfo": ScalaInfo(contains_macros = contains_macros),
},
)

def phase_scalainfo_provider_macro(ctx, p):
return _phase_scalainfo_provider_implementation(contains_macros = True)

def phase_scalainfo_provider_non_macro(ctx, p):
return _phase_scalainfo_provider_implementation(contains_macros = False)
63 changes: 35 additions & 28 deletions scala/private/phases/phases.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@ load(
_extras_phases = "extras_phases",
_run_phases = "run_phases",
)
load(
"@io_bazel_rules_scala//scala/private:phases/phase_write_executable.bzl",
_phase_write_executable_common = "phase_write_executable_common",
_phase_write_executable_junit_test = "phase_write_executable_junit_test",
_phase_write_executable_repl = "phase_write_executable_repl",
_phase_write_executable_scalatest = "phase_write_executable_scalatest",
)
load(
"@io_bazel_rules_scala//scala/private:phases/phase_java_wrapper.bzl",
_phase_java_wrapper_common = "phase_java_wrapper_common",
_phase_java_wrapper_repl = "phase_java_wrapper_repl",
)
load(
"@io_bazel_rules_scala//scala/private:phases/phase_collect_jars.bzl",
_phase_collect_jars_common = "phase_collect_jars_common",
Expand All @@ -27,46 +15,62 @@ load(
_phase_collect_jars_repl = "phase_collect_jars_repl",
_phase_collect_jars_scalatest = "phase_collect_jars_scalatest",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_collect_exports_jars.bzl", _phase_collect_exports_jars = "phase_collect_exports_jars")
load("@io_bazel_rules_scala//scala/private:phases/phase_collect_srcjars.bzl", _phase_collect_srcjars = "phase_collect_srcjars")
load(
"@io_bazel_rules_scala//scala/private:phases/phase_compile.bzl",
_phase_compile_binary = "phase_compile_binary",
_phase_compile_common = "phase_compile_common",
_phase_compile_junit_test = "phase_compile_junit_test",
_phase_compile_library = "phase_compile_library",
_phase_compile_library_for_plugin_bootstrapping = "phase_compile_library_for_plugin_bootstrapping",
_phase_compile_macro_library = "phase_compile_macro_library",
_phase_compile_repl = "phase_compile_repl",
_phase_compile_scalatest = "phase_compile_scalatest",
)
load(
"@io_bazel_rules_scala//scala/private:phases/phase_runfiles.bzl",
_phase_runfiles_common = "phase_runfiles_common",
_phase_runfiles_library = "phase_runfiles_library",
_phase_runfiles_scalatest = "phase_runfiles_scalatest",
)
load(
"@io_bazel_rules_scala//scala/private:phases/phase_coverage.bzl",
_phase_coverage_common = "phase_coverage_common",
_phase_coverage_library = "phase_coverage_library",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles")
load("@io_bazel_rules_scala//scala/private:phases/phase_declare_executable.bzl", _phase_declare_executable = "phase_declare_executable")
load("@io_bazel_rules_scala//scala/private:phases/phase_default_info.bzl", _phase_default_info = "phase_default_info")
load("@io_bazel_rules_scala//scala/private:phases/phase_scalac_provider.bzl", _phase_scalac_provider = "phase_scalac_provider")
load("@io_bazel_rules_scala//scala/private:phases/phase_write_manifest.bzl", _phase_write_manifest = "phase_write_manifest")
load("@io_bazel_rules_scala//scala/private:phases/phase_collect_srcjars.bzl", _phase_collect_srcjars = "phase_collect_srcjars")
load("@io_bazel_rules_scala//scala/private:phases/phase_collect_exports_jars.bzl", _phase_collect_exports_jars = "phase_collect_exports_jars")
load(
"@io_bazel_rules_scala//scala/private:phases/phase_dependency.bzl",
_phase_dependency_common = "phase_dependency_common",
_phase_dependency_library_for_plugin_bootstrapping = "phase_dependency_library_for_plugin_bootstrapping",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_declare_executable.bzl", _phase_declare_executable = "phase_declare_executable")
load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_merge_jars = "phase_merge_jars")
load(
"@io_bazel_rules_scala//scala/private:phases/phase_java_wrapper.bzl",
_phase_java_wrapper_common = "phase_java_wrapper_common",
_phase_java_wrapper_repl = "phase_java_wrapper_repl",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_jvm_flags.bzl", _phase_jvm_flags = "phase_jvm_flags")
load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_merge_jars = "phase_merge_jars")
load(
"@io_bazel_rules_scala//scala/private:phases/phase_runfiles.bzl",
_phase_runfiles_common = "phase_runfiles_common",
_phase_runfiles_library = "phase_runfiles_library",
_phase_runfiles_scalatest = "phase_runfiles_scalatest",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_scalac_provider.bzl", _phase_scalac_provider = "phase_scalac_provider")
load("@io_bazel_rules_scala//scala/private:phases/phase_scalacopts.bzl", _phase_scalacopts = "phase_scalacopts")
load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles")
load("@io_bazel_rules_scala//scala/private:phases/phase_scalafmt.bzl", _phase_scalafmt = "phase_scalafmt")
load("@io_bazel_rules_scala//scala/private:phases/phase_test_environment.bzl", _phase_test_environment = "phase_test_environment")
load(
"@io_bazel_rules_scala//scala/private:phases/phase_scalainfo_provider.bzl",
_phase_scalainfo_provider_macro = "phase_scalainfo_provider_macro",
_phase_scalainfo_provider_non_macro = "phase_scalainfo_provider_non_macro",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_semanticdb.bzl", _phase_semanticdb = "phase_semanticdb")
load("@io_bazel_rules_scala//scala/private:phases/phase_test_environment.bzl", _phase_test_environment = "phase_test_environment")
load(
"@io_bazel_rules_scala//scala/private:phases/phase_write_executable.bzl",
_phase_write_executable_common = "phase_write_executable_common",
_phase_write_executable_junit_test = "phase_write_executable_junit_test",
_phase_write_executable_repl = "phase_write_executable_repl",
_phase_write_executable_scalatest = "phase_write_executable_scalatest",
)
load("@io_bazel_rules_scala//scala/private:phases/phase_write_manifest.bzl", _phase_write_manifest = "phase_write_manifest")

# API
run_phases = _run_phases
Expand All @@ -75,6 +79,10 @@ extras_phases = _extras_phases
# scalac_provider
phase_scalac_provider = _phase_scalac_provider

# scalainfo_provider
phase_scalainfo_provider_macro = _phase_scalainfo_provider_macro
phase_scalainfo_provider_non_macro = _phase_scalainfo_provider_non_macro

# collect_srcjars
phase_collect_srcjars = _phase_collect_srcjars

Expand Down Expand Up @@ -128,7 +136,6 @@ phase_collect_jars_common = _phase_collect_jars_common
phase_compile_binary = _phase_compile_binary
phase_compile_library = _phase_compile_library
phase_compile_library_for_plugin_bootstrapping = _phase_compile_library_for_plugin_bootstrapping
phase_compile_macro_library = _phase_compile_macro_library
phase_compile_junit_test = _phase_compile_junit_test
phase_compile_repl = _phase_compile_repl
phase_compile_scalatest = _phase_compile_scalatest
Expand Down
2 changes: 2 additions & 0 deletions scala/private/rules/scala_binary.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ load(
"phase_runfiles_common",
"phase_scalac_provider",
"phase_scalacopts",
"phase_scalainfo_provider_non_macro",
"phase_semanticdb",
"phase_write_executable_common",
"phase_write_manifest",
Expand All @@ -36,6 +37,7 @@ def _scala_binary_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_non_macro),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
("collect_jars", phase_collect_jars_common),
Expand Down
19 changes: 15 additions & 4 deletions scala/private/rules/scala_doc.bzl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Scaladoc support"""

load("@io_bazel_rules_scala//scala:providers.bzl", "ScalaInfo")
load("@io_bazel_rules_scala//scala/private:common.bzl", "collect_plugin_paths")

ScaladocAspectInfo = provider(fields = [
"src_files", #depset[File]
"compile_jars", #depset[File]
"macro_classpath", #depset[File]
"plugins", #depset[Target]
])

Expand All @@ -29,23 +31,28 @@ def _scaladoc_aspect_impl(target, ctx, transitive = True):
if hasattr(ctx.rule.attr, "plugins"):
plugins = depset(ctx.rule.attr.plugins)

macro_classpath = []

for dependency in ctx.rule.attr.deps:
if ScalaInfo in dependency and dependency[ScalaInfo].contains_macros:
macro_classpath.append(dependency[JavaInfo].transitive_runtime_jars)

# Sometimes we only want to generate scaladocs for a single target and not all of its
# dependencies
transitive_srcs = depset()
transitive_compile_jars = depset()
transitive_plugins = depset()

if transitive:
for dep in ctx.rule.attr.deps:
if ScaladocAspectInfo in dep:
aspec_info = dep[ScaladocAspectInfo]
transitive_srcs = aspec_info.src_files
transitive_compile_jars = aspec_info.compile_jars
transitive_plugins = aspec_info.plugins

return [ScaladocAspectInfo(
src_files = depset(transitive = [src_files, transitive_srcs]),
compile_jars = depset(transitive = [compile_jars, transitive_compile_jars]),
compile_jars = depset(transitive = [compile_jars]),
macro_classpath = depset(transitive = macro_classpath),
plugins = depset(transitive = [plugins, transitive_plugins]),
)]

Expand Down Expand Up @@ -73,11 +80,15 @@ def _scala_doc_impl(ctx):
src_files = depset(transitive = [dep[ScaladocAspectInfo].src_files for dep in ctx.attr.deps])
compile_jars = depset(transitive = [dep[ScaladocAspectInfo].compile_jars for dep in ctx.attr.deps])

# See the documentation for `collect_jars` in `scala/private/common.bzl` to understand why this is prepended to the
# classpath
macro_classpath = depset(transitive = [dep[ScaladocAspectInfo].macro_classpath for dep in ctx.attr.deps])

# Get the 'real' paths to the plugin jars.
plugins = collect_plugin_paths(depset(transitive = [dep[ScaladocAspectInfo].plugins for dep in ctx.attr.deps]).to_list())

# Construct the full classpath depset since we need to add compiler plugins too.
classpath = depset(transitive = [plugins, compile_jars])
classpath = depset(transitive = [macro_classpath, plugins, compile_jars])

# Construct scaladoc args, which also include scalac args.
# See `scaladoc -help` for more information.
Expand Down
2 changes: 2 additions & 0 deletions scala/private/rules/scala_junit_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ load(
"phase_runfiles_common",
"phase_scalac_provider",
"phase_scalacopts",
"phase_scalainfo_provider_non_macro",
"phase_semanticdb",
"phase_test_environment",
"phase_write_executable_junit_test",
Expand All @@ -42,6 +43,7 @@ def _scala_junit_test_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_non_macro),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
("collect_jars", phase_collect_jars_junit_test),
Expand Down
8 changes: 6 additions & 2 deletions scala/private/rules/scala_library.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ load(
"phase_collect_srcjars",
"phase_compile_library",
"phase_compile_library_for_plugin_bootstrapping",
"phase_compile_macro_library",
"phase_coverage_common",
"phase_coverage_library",
"phase_default_info",
Expand All @@ -35,6 +34,8 @@ load(
"phase_runfiles_library",
"phase_scalac_provider",
"phase_scalacopts",
"phase_scalainfo_provider_macro",
"phase_scalainfo_provider_non_macro",
"phase_semanticdb",
"phase_write_manifest",
"run_phases",
Expand Down Expand Up @@ -63,6 +64,7 @@ def _scala_library_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_non_macro),
("collect_srcjars", phase_collect_srcjars),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
Expand Down Expand Up @@ -151,6 +153,7 @@ def _scala_library_for_plugin_bootstrapping_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_non_macro),
("collect_srcjars", phase_collect_srcjars),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_library_for_plugin_bootstrapping),
Expand Down Expand Up @@ -226,13 +229,14 @@ def _scala_macro_library_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_macro),
("collect_srcjars", phase_collect_srcjars),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
("collect_jars", phase_collect_jars_macro_library),
("scalacopts", phase_scalacopts),
("semanticdb", phase_semanticdb),
("compile", phase_compile_macro_library),
("compile", phase_compile_library),
("coverage", phase_coverage_common),
("merge_jars", phase_merge_jars),
("runfiles", phase_runfiles_library),
Expand Down
2 changes: 2 additions & 0 deletions scala/private/rules/scala_repl.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ load(
"phase_runfiles_common",
"phase_scalac_provider",
"phase_scalacopts",
"phase_scalainfo_provider_non_macro",
"phase_semanticdb",
"phase_write_executable_repl",
"phase_write_manifest",
Expand All @@ -36,6 +37,7 @@ def _scala_repl_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_non_macro),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
# need scala-compiler for MainGenericRunner below
Expand Down
2 changes: 2 additions & 0 deletions scala/private/rules/scala_test.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ load(
"phase_runfiles_scalatest",
"phase_scalac_provider",
"phase_scalacopts",
"phase_scalainfo_provider_non_macro",
"phase_semanticdb",
"phase_test_environment",
"phase_write_executable_scalatest",
Expand All @@ -38,6 +39,7 @@ def _scala_test_impl(ctx):
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("scalainfo_provider", phase_scalainfo_provider_non_macro),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
("collect_jars", phase_collect_jars_scalatest),
Expand Down
Loading

0 comments on commit 3ca60fb

Please sign in to comment.