Skip to content

Commit

Permalink
Adds additional operations for working with MDC in suspending
Browse files Browse the repository at this point in the history
and non-suspending contexts

GitOrigin-RevId: 2f1a75b27aa07cf9508a61ae9e10f60370aa8e3d
  • Loading branch information
jclyne authored and svc-squareup-copybara committed Feb 6, 2025
1 parent 57778b6 commit b9426c8
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 105 deletions.
1 change: 1 addition & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ kotlinStdLibJdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8", version
kotlinTest = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
kotlinxHtml = { module = "org.jetbrains.kotlinx:kotlinx-html-jvm", version = "0.12.0" }
kotlinxCoroutinesCore = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version = "1.10.1" }
kotlinxCoroutinesSlf4j = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-slf4j", version = "1.10.1" }
kotlinxCoroutinesTest = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version = "1.10.1" }
kubernetesClient = { module = "io.kubernetes:client-java", version = "18.0.1" }
kubernetesClientApi = { module = "io.kubernetes:client-java-api", version = "18.0.1" }
Expand Down
2 changes: 2 additions & 0 deletions misk-api/api/misk-api.api
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ public final class misk/client/NetworkInterceptorWrapper : okhttp3/Interceptor {
public abstract interface class misk/logging/Mdc {
public abstract fun clear ()V
public abstract fun get (Ljava/lang/String;)Ljava/lang/String;
public abstract fun getCopyOfContextMap ()Ljava/util/Map;
public abstract fun put (Ljava/lang/String;Ljava/lang/String;)V
public abstract fun setContextMap (Ljava/util/Map;)V
}

public abstract interface class misk/scope/ActionScoped {
Expand Down
7 changes: 7 additions & 0 deletions misk-api/src/main/kotlin/misk/logging/Mdc.kt
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
package misk.logging

typealias MdcContextMap = Map<String, String>

interface Mdc {
fun put(key: String, value: String?)

fun get(key: String): String?

fun clear()

fun setContextMap(context: MdcContextMap)

fun getCopyOfContextMap(): MdcContextMap?

}

12 changes: 12 additions & 0 deletions misk/api/misk.api
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,19 @@ public final class misk/logging/MiskMdc : misk/logging/Mdc {
public static final field INSTANCE Lmisk/logging/MiskMdc;
public fun clear ()V
public fun get (Ljava/lang/String;)Ljava/lang/String;
public fun getCopyOfContextMap ()Ljava/util/Map;
public fun put (Ljava/lang/String;Ljava/lang/String;)V
public fun setContextMap (Ljava/util/Map;)V
}

public final class misk/logging/ScopedMdcKt {
public static final fun withMdc (Lmisk/logging/Mdc;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function0;)V
public static final fun withMdc (Lmisk/logging/Mdc;[Lkotlin/Pair;Lkotlin/jvm/functions/Function0;)V
}

public final class misk/logging/coroutines/ScopedMdcKt {
public static final fun withMdc (Lmisk/logging/Mdc;Ljava/lang/String;Ljava/lang/String;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun withMdc (Lmisk/logging/Mdc;[Lkotlin/Pair;Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class misk/monitoring/JvmMetrics {
Expand Down
2 changes: 2 additions & 0 deletions misk/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies {
implementation(libs.kotlinReflect)
implementation(libs.kotlinStdLibJdk8)
implementation(libs.kotlinxCoroutinesCore)
implementation(libs.kotlinxCoroutinesSlf4j)
implementation(libs.moshiAdapters)
implementation(libs.okio)
implementation(libs.openTracingConcurrent)
Expand Down Expand Up @@ -81,6 +82,7 @@ dependencies {
testImplementation(libs.junitParams)
testImplementation(libs.kotestAssertions)
testImplementation(libs.kotlinTest)
testImplementation(libs.kotlinxCoroutinesCore)
testImplementation(libs.kotlinxCoroutinesTest)
testImplementation(libs.logbackClassic)
testImplementation(libs.okHttpMockWebServer)
Expand Down
55 changes: 0 additions & 55 deletions misk/src/main/kotlin/misk/logging/DynamicMdcContext.kt

This file was deleted.

10 changes: 10 additions & 0 deletions misk/src/main/kotlin/misk/logging/MiskMdc.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,14 @@ object MiskMdc : Mdc {
override fun clear() {
MDC.clear()
}

override fun setContextMap(context: MdcContextMap) {
if (context.isNotEmpty()) {
MDC.setContextMap(context)
} else {
MDC.clear()
}
}

override fun getCopyOfContextMap(): MdcContextMap? = MDC.getCopyOfContextMap()
}
23 changes: 23 additions & 0 deletions misk/src/main/kotlin/misk/logging/ScopedMdc.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package misk.logging

import org.slf4j.MDC

/**
* Adds the given key, value pair to the MDC for the duration of the block.
*/
inline fun Mdc.withMdc(key: String, value: String, block: () -> Unit) =
withMdc(key to value, block = block)

/**
* Adds the given tags to the MDC for the duration of the block.
*/
inline fun Mdc.withMdc(vararg tags: Pair<String, String>, block: () -> Unit) {
val oldState = getCopyOfContextMap()
return try {
tags.forEach { (key, value) -> put(key, value) }
block()
} finally {
oldState?.let { setContextMap(it) } ?: clear()
}
}

35 changes: 35 additions & 0 deletions misk/src/main/kotlin/misk/logging/coroutines/ScopedMdc.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package misk.logging.coroutines

import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.withContext
import misk.logging.Mdc
import mu.KotlinLogging
import wisp.logging.getLogger
import kotlin.coroutines.coroutineContext

/**
* Adds the given key, value pair to the MDC for the duration of the block.
* This is coroutine safe, so the additions will be added to the coroutine context
*/
suspend inline fun Mdc.withMdc(key: String, value: String, crossinline block: suspend () -> Unit) =
withMdc(key to value, block = block)

/**
* Adds the given tags to the MDC for the duration of the block.
* This is coroutine safe, so the additions will be added to the coroutine context
*/
suspend inline fun Mdc.withMdc(
vararg tags: Pair<String, String>,
crossinline block: suspend () -> Unit
) {
if(coroutineContext[MDCContext] == null) {
KotlinLogging.logger("misk.logging.coroutines.ScopedMdc").warn {
"MDCContext is not present in the coroutine context, this is required to restore the previous MDC state"
}
}
tags.forEach { (key, value) -> put(key, value) }
return withContext(MDCContext()) {
block()
}
}

4 changes: 2 additions & 2 deletions misk/src/main/kotlin/misk/web/actions/WebActions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package misk.web.actions

import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.slf4j.MDCContext
import misk.ApplicationInterceptor
import misk.Chain
import misk.grpc.GrpcMessageSinkChannel
import misk.grpc.GrpcMessageSourceChannel
import misk.logging.DynamicMdcContext
import misk.scope.ActionScope
import misk.web.HttpCall
import misk.web.RealChain
Expand Down Expand Up @@ -40,7 +40,7 @@ internal fun WebAction.asChain(
} else {
// Handle suspending invocation, this includes building out the context to propagate MDC
// and action scope.
val context = DynamicMdcContext() +
val context = MDCContext() +
if (scope.inScope()) {
scope.asContextElement()
} else {
Expand Down
48 changes: 0 additions & 48 deletions misk/src/test/kotlin/misk/logging/DynamicMdcContextTest.kt

This file was deleted.

79 changes: 79 additions & 0 deletions misk/src/test/kotlin/misk/logging/ScopedMdcKtTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package misk.logging

import jakarta.inject.Inject
import kotlinx.coroutines.delay
import kotlinx.coroutines.slf4j.MDCContext
import kotlinx.coroutines.test.runTest
import misk.MiskTestingServiceModule
import misk.testing.MiskTest
import misk.testing.MiskTestModule
import org.junit.jupiter.api.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull
import misk.logging.coroutines.withMdc as withMdcCoroutines

@MiskTest(startService = false)
internal class ScopedMdcKtTest {
@MiskTestModule
val module = MiskTestingServiceModule()

@Inject
lateinit var mdc: Mdc

@Test
fun `test withMdc in a coroutine for key value pairs`() = runTest(MDCContext()) {
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
mdc.withMdcCoroutines(*tags) {
tags.assertTags()
delay(100)
tags.assertTags()
}
tags.asserMissingTags()
}

@Test
fun `test withMdc for key value pairs`() {
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
mdc.withMdc(*tags) {
tags.assertTags()
}
tags.asserMissingTags()
}

@Test
fun `test withMdc in a coroutine for key value pair overrides`() = runTest(MDCContext()) {
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
mdc.withMdcCoroutines(*tags) {
tags.assertTags()
delay(100)
tags.assertTags()
val updatedTags = tags.map { if (it.first == "key1"){it.first to it.second+"00"} else {it} }.toTypedArray()
mdc.withMdcCoroutines(*updatedTags) {
updatedTags.assertTags()
delay(100)
updatedTags.assertTags()
}
tags.forEach { it.asserTag() }
}
tags.asserMissingTags()
}

@Test
fun `test withMdc for key value pair overrides`() {
val tags = (1..3).map { "key$it" to "value$it" }.toTypedArray()
mdc.withMdc(*tags) {
tags.assertTags()
val updatedTags = tags.map { if (it.first == "key1"){it.first to it.second+"00"} else {it} }.toTypedArray()
mdc.withMdc(*updatedTags) {
updatedTags.assertTags()
}
tags.forEach { it.asserTag() }
}
tags.asserMissingTags()
}

fun Pair<String, String>.asserTag() = assertEquals(second, mdc.get(first))
fun Array<Pair<String, String>>.assertTags() = forEach { it.asserTag() }
fun Pair<String, String>.asserMissingTag() = assertNull( mdc.get(first))
fun Array<Pair<String, String>>.asserMissingTags() = forEach { it.asserMissingTag() }
}

0 comments on commit b9426c8

Please sign in to comment.