Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Unverified Commit a8acbff3 authored by Ricki Hirner's avatar Ricki Hirner Committed by GitHub
Browse files

Fix repeated sending of POST/PUT requests (#157)

* Add Ktor client auth dependency and test for repeated POST requests on 401

* Fix repeated sending of POST/PUT using OutgoingContent

- Replace `provideBody` and `mimeType` parameters with single `body: OutgoingContent` parameter
- Update tests to use `TextContent` and streaming `OutgoingContent.ReadChannelContent`
- Remove redundant `contentType()` calls as it's handled by OutgoingContent

* Fix streaming body consumption in repeated POST/PUT requests

- Update documentation to clarify that `OutgoingContent.ReadChannelContent.readFrom` must return an unconsumed channel
- Add test to verify that a new channel is created for each request
- Fix test case to use correct sample text

* Add WWW-Authenticate header to 401 response in auth test
parent 59ab44eb
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -63,6 +63,7 @@ dependencies {

    testImplementation(libs.junit4)
    testImplementation(libs.kotlin.coroutines.test)
    testImplementation(libs.ktor.client.auth)
    testImplementation(libs.ktor.client.mock)
    testImplementation(libs.okhttp.mockwebserver)
}
 No newline at end of file
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ ktor = "3.4.1"
[libraries]
junit4 = { module = "junit:junit", version.ref = "junit4" }
kotlin-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlin-coroutines" }
ktor-client-auth = { module = "io.ktor:ktor-client-auth", version.ref = "ktor" }
ktor-client-core = { module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-encoding = { module = "io.ktor:ktor-client-encoding", version.ref = "ktor" }
ktor-client-mock = { module = "io.ktor:ktor-client-mock", version.ref = "ktor" }
+9 −12
Original line number Diff line number Diff line
@@ -42,6 +42,7 @@ import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.URLBuilder
import io.ktor.http.Url
import io.ktor.http.content.OutgoingContent
import io.ktor.http.contentType
import io.ktor.http.isSecure
import io.ktor.http.isSuccess
@@ -394,14 +395,13 @@ open class DavResource(
     *
     * Follows up to [MAX_REDIRECTS] redirects.
     *
     * @param provideBody       resource body to upload (unconsumed, may be called multiple times on redirects)
     * @param mimeType          content type of resource body
     * @param body              resource body to upload (use [OutgoingContent.ReadChannelContent]
     * for streaming; ensure that every [OutgoingContent.ReadChannelContent.readFrom] returns an unconsumed channel)
     * @param additionalHeaders additional headers to send
     * @param callback          called with server response on success
     */
    suspend fun post(
        provideBody: () -> ByteReadChannel,
        mimeType: ContentType,
        body: OutgoingContent,
        additionalHeaders: Headers? = null,
        callback: ResponseCallback
    ) {
@@ -410,8 +410,7 @@ open class DavResource(
                if (additionalHeaders != null)
                    headers.appendAll(additionalHeaders)

                contentType(mimeType)
                setBody(provideBody())
                setBody(body)
            }
        }) { response ->
            checkStatus(response)
@@ -424,8 +423,8 @@ open class DavResource(
     *
     * Follows up to [MAX_REDIRECTS] redirects.
     *
     * @param provideBody       resource body to upload (unconsumed, may be called multiple times on redirects)
     * @param mimeType          content type of resource body
     * @param body              resource body to upload (use [OutgoingContent.ReadChannelContent]
     * for streaming; ensure that every [OutgoingContent.ReadChannelContent.readFrom] returns an unconsumed channel)
     * @param additionalHeaders additional headers to send (like [HttpHeaders.IfNoneMatch] to prevent overwriting)
     * @param callback          called with server response on success
     *
@@ -434,8 +433,7 @@ open class DavResource(
     * @throws DavException on HTTPS -> HTTP redirect
     */
    suspend fun put(
        provideBody: () -> ByteReadChannel,
        mimeType: ContentType,
        body: OutgoingContent,
        additionalHeaders: Headers? = null,
        callback: ResponseCallback
    ) {
@@ -444,8 +442,7 @@ open class DavResource(
                if (additionalHeaders != null)
                    headers.appendAll(additionalHeaders)

                contentType(mimeType)
                setBody(provideBody())
                setBody(body)
            }
        }) { response ->
            checkStatus(response)
+66 −1
Original line number Diff line number Diff line
@@ -18,6 +18,10 @@ import at.bitfire.dav4jvm.property.webdav.WebDAV
import io.ktor.client.HttpClient
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.client.engine.mock.toByteArray
import io.ktor.client.plugins.auth.Auth
import io.ktor.client.plugins.auth.providers.BasicAuthCredentials
import io.ktor.client.plugins.auth.providers.basic
import io.ktor.client.statement.request
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
@@ -25,6 +29,8 @@ import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.URLBuilder
import io.ktor.http.Url
import io.ktor.http.content.OutgoingContent
import io.ktor.http.content.TextContent
import io.ktor.http.headersOf
import io.ktor.http.takeFrom
import io.ktor.http.withCharset
@@ -270,8 +276,9 @@ class DavCollectionTest {
        val dav = DavCollection(httpClient, sampleUrl)

        var called = false
        dav.post({ ByteReadChannel(sampleText) }, ContentType.Text.Plain) { response ->
        dav.post(TextContent(sampleText, ContentType.Text.Plain)) { response ->
            assertEquals(HttpMethod.Post, response.request.method)
            assertEquals(ContentType.Text.Plain, response.request.content.contentType)
            assertEquals(HttpStatusCode.Created, response.status)
            assertEquals(response.request.url, dav.location)
            called = true
@@ -279,4 +286,62 @@ class DavCollectionTest {
        assertTrue(called)
    }

    @Test
    fun testPostStreamingRepeatedlyBecause401() = runTest {
        var requestCount = 0

        val mockEngine = MockEngine { request ->
            requestCount++

            // Verify that request body is always sent – https://github.com/bitfireAT/dav4jvm/issues/156
            assertEquals(sampleText, request.body.toByteArray().toString(Charsets.UTF_8))

            if (requestCount == 1) {
                // First request: respond with 401 to indicate that request shall be sent again
                respond(
                    content = "Send Auth",
                    status = HttpStatusCode.Unauthorized,
                    headers = headersOf(HttpHeaders.WWWAuthenticate, "Basic realm=\"test\"")
                )
            } else {
                // Second request: respond with success
                respond(
                    content = sampleText,
                    status = HttpStatusCode.Created,
                    headers = headersOf(HttpHeaders.ContentType, ContentType.Text.Plain.toString())
                )
            }
        }
        val httpClient = HttpClient(mockEngine) {
            install(Auth) {     // authentication plugin retries request when receiving 401
                basic {
                    credentials {
                        BasicAuthCredentials("test", "test")
                    }
                }
            }
        }
        val dav = DavCollection(httpClient, sampleUrl)

        var called = false
        var channelsCreated = 0     // ByteReadChannel is created anew for every request
        val streamingBody = object : OutgoingContent.ReadChannelContent() {
            override fun readFrom(): ByteReadChannel {
                channelsCreated++
                return ByteReadChannel(sampleText)
            }
            override val contentType = ContentType.Text.Plain
        }
        dav.post(streamingBody) { response ->
            assertEquals(HttpMethod.Post, response.request.method)
            assertEquals(ContentType.Text.Plain, response.request.content.contentType)
            assertEquals(HttpStatusCode.Created, response.status)
            assertEquals(response.request.url, dav.location)
            called = true
        }
        assertTrue(called)
        assertEquals(2, channelsCreated)
        assertEquals(2, requestCount)
    }

}
 No newline at end of file
+6 −21
Original line number Diff line number Diff line
@@ -41,7 +41,6 @@ import io.ktor.http.fullPath
import io.ktor.http.headersOf
import io.ktor.http.takeFrom
import io.ktor.http.withCharset
import io.ktor.utils.io.ByteReadChannel
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertEquals
import org.junit.Assert.assertFalse
@@ -367,10 +366,7 @@ class DavResourceTest {

        // 200 OK
        var called = false
        dav.post(
            provideBody = { ByteReadChannel("body") },
            mimeType = ContentType.parse("application/x-test-result")
        ) { response ->
        dav.post(TextContent("body", ContentType.parse("application/x-test-result"))) { response ->
            called = true
            assertEquals(sampleText, response.bodyAsText())

@@ -436,10 +432,7 @@ class DavResourceTest {
        val dav = DavResource(httpClient, sampleUrl)

        var called = false
        dav.post(
            provideBody = { ByteReadChannel("body") },
            mimeType = ContentType.parse("application/x-test-result")
        ) { response ->
        dav.post(TextContent("body", ContentType.parse("application/x-test-result"))) { response ->
            called = true
            assertEquals(sampleText, response.bodyAsText())
            val eTag = GetETag(response.headers[HttpHeaders.ETag])
@@ -462,10 +455,7 @@ class DavResourceTest {
        val dav = DavResource(httpClient, sampleUrl)

        var called = false
        dav.post(
            provideBody = { ByteReadChannel("body") },
            mimeType = ContentType.Text.Plain
        ) { response ->
        dav.post(TextContent("body", ContentType.Text.Plain)) { response ->
            called = true
            assertNull(response.headers[HttpHeaders.ETag])
        }
@@ -1060,10 +1050,7 @@ class DavResourceTest {
        val dav = DavResource(httpClient, sampleUrl)

        var called = false
        dav.put(
            provideBody = { ByteReadChannel(sampleText) },
            mimeType = ContentType.Text.Plain
        ) { response ->
        dav.put(TextContent(sampleText, ContentType.Text.Plain)) { response ->
            called = true
            val eTag = GetETag.fromHttpResponse(response)!!
            assertEquals("Weak PUT ETag", eTag.eTag)
@@ -1095,8 +1082,7 @@ class DavResourceTest {

        var called = false
        dav.put(
            { ByteReadChannel(sampleText) },
            ContentType.Text.Plain,
            TextContent(sampleText, ContentType.Text.Plain),
            headersOf(HttpHeaders.IfNoneMatch, "*")
        ) { response ->
            called = true
@@ -1121,8 +1107,7 @@ class DavResourceTest {
        var called = false
        try {
            dav.put(
                { ByteReadChannel(sampleText) },
                ContentType.Text.Plain,
                TextContent(sampleText, ContentType.Text.Plain),
                headersOf(HttpHeaders.IfMatch, "\"ExistingETag\"")
            ) {
                called = true