From 829ba96923ae6a00e19e8a9569d8bbe80ceb82a2 Mon Sep 17 00:00:00 2001 From: Hossein Naderi Date: Mon, 3 Jul 2023 22:29:09 +0330 Subject: [PATCH] Implemented CurlMultiSocket This implementation uses curl multi socket drive, and uses FileDescriptorPoller from cats effect. It allows using this library alongside other cats effect libraries. --- .../org/http4s/curl/http/CurlClient.scala | 3 + .../org/http4s/curl/http/CurlRequest.scala | 65 ++++ .../org/http4s/curl/unsafe/CURLMcode.scala | 30 ++ .../org/http4s/curl/unsafe/CURLcode.scala | 30 ++ .../curl/unsafe/CurlExecutorScheduler.scala | 2 + .../curl/unsafe/CurlMultiSocketImpl.scala | 286 ++++++++++++++++++ .../org/http4s/curl/unsafe/libcurl.scala | 14 +- .../http4s/curl/websocket/Connection.scala | 37 +++ .../http4s/curl/websocket/CurlWSClient.scala | 49 +++ .../scala/CurlClientMultiSocketSuite.scala | 85 ++++++ .../scala/CurlWSClientMultiSocketSuite.scala | 90 ++++++ 11 files changed, 681 insertions(+), 10 deletions(-) create mode 100644 curl/src/main/scala/org/http4s/curl/unsafe/CURLMcode.scala create mode 100644 curl/src/main/scala/org/http4s/curl/unsafe/CURLcode.scala create mode 100644 curl/src/main/scala/org/http4s/curl/unsafe/CurlMultiSocketImpl.scala create mode 100644 tests/http/src/test/scala/CurlClientMultiSocketSuite.scala create mode 100644 tests/websocket/src/test/scala/CurlWSClientMultiSocketSuite.scala diff --git a/curl/src/main/scala/org/http4s/curl/http/CurlClient.scala b/curl/src/main/scala/org/http4s/curl/http/CurlClient.scala index 6412220..e3bd98c 100644 --- a/curl/src/main/scala/org/http4s/curl/http/CurlClient.scala +++ b/curl/src/main/scala/org/http4s/curl/http/CurlClient.scala @@ -19,10 +19,13 @@ package org.http4s.curl.http import cats.effect._ import org.http4s.client.Client import org.http4s.curl.unsafe.CurlExecutorScheduler +import org.http4s.curl.unsafe.CurlMultiSocket private[curl] object CurlClient { def apply(ec: CurlExecutorScheduler): Client[IO] = Client(CurlRequest(ec, _)) + def multiSocket(ms: CurlMultiSocket): Client[IO] = Client(CurlRequest.applyMultiSocket(ms, _)) + def get: IO[Client[IO]] = IO.executionContext.flatMap { case ec: CurlExecutorScheduler => IO.pure(apply(ec)) case _ => IO.raiseError(new RuntimeException("Not running on CurlExecutorScheduler")) diff --git a/curl/src/main/scala/org/http4s/curl/http/CurlRequest.scala b/curl/src/main/scala/org/http4s/curl/http/CurlRequest.scala index cd5c7a7..5480f64 100644 --- a/curl/src/main/scala/org/http4s/curl/http/CurlRequest.scala +++ b/curl/src/main/scala/org/http4s/curl/http/CurlRequest.scala @@ -22,6 +22,7 @@ import org.http4s.Response import org.http4s.curl.internal.Utils import org.http4s.curl.internal._ import org.http4s.curl.unsafe.CurlExecutorScheduler +import org.http4s.curl.unsafe.CurlMultiSocket private[curl] object CurlRequest { private def setup( @@ -78,6 +79,57 @@ private[curl] object CurlRequest { ) ) + private def setup( + handle: CurlEasy, + send: RequestSend, + recv: RequestRecv, + req: Request[IO], + ): Resource[IO, Unit] = + Utils.newZone.flatMap(implicit zone => + CurlSList().evalMap(headers => + IO { + // TODO add in options + // handle.setVerbose(true) + + import org.http4s.curl.unsafe.libcurl_const + import scala.scalanative.unsafe._ + import org.http4s.Header + import org.http4s.HttpVersion + import org.typelevel.ci._ + + handle.setCustomRequest(toCString(req.method.renderString)) + + handle.setUpload(true) + + handle.setUrl(toCString(req.uri.renderString)) + + val httpVersion = req.httpVersion match { + case HttpVersion.`HTTP/1.0` => libcurl_const.CURL_HTTP_VERSION_1_0 + case HttpVersion.`HTTP/1.1` => libcurl_const.CURL_HTTP_VERSION_1_1 + case HttpVersion.`HTTP/2` => libcurl_const.CURL_HTTP_VERSION_2 + case HttpVersion.`HTTP/3` => libcurl_const.CURL_HTTP_VERSION_3 + case _ => libcurl_const.CURL_HTTP_VERSION_NONE + } + handle.setHttpVersion(httpVersion) + + req.headers // curl adds these headers automatically, so we explicitly disable them + .transform(Header.Raw(ci"Expect", "") :: Header.Raw(ci"Transfer-Encoding", "") :: _) + .foreach(header => headers.append(header.toString)) + + handle.setHttpHeader(headers.toPtr) + + handle.setReadData(Utils.toPtr(send)) + handle.setReadFunction(RequestSend.readCallback(_, _, _, _)) + + handle.setHeaderData(Utils.toPtr(recv)) + handle.setHeaderFunction(RequestRecv.headerCallback(_, _, _, _)) + + handle.setWriteData(Utils.toPtr(recv)) + handle.setWriteFunction(RequestRecv.writeCallback(_, _, _, _)) + } + ) + ) + def apply(ec: CurlExecutorScheduler, req: Request[IO]): Resource[IO, Response[IO]] = for { gc <- GCRoot() handle <- CurlEasy() @@ -89,4 +141,17 @@ private[curl] object CurlRequest { _ <- req.body.through(send.pipe).compile.drain.background resp <- recv.response() } yield resp + + def applyMultiSocket(ms: CurlMultiSocket, req: Request[IO]): Resource[IO, Response[IO]] = for { + gc <- GCRoot() + handle <- CurlEasy() + flow <- FlowControl(handle) + send <- RequestSend(flow) + recv <- RequestRecv(flow) + _ <- gc.add(send, recv) + _ <- setup(handle, send, recv, req) + _ <- ms.addHandlerTerminating(handle, recv.onTerminated).toResource + _ <- req.body.through(send.pipe).compile.drain.background + resp <- recv.response() + } yield resp } diff --git a/curl/src/main/scala/org/http4s/curl/unsafe/CURLMcode.scala b/curl/src/main/scala/org/http4s/curl/unsafe/CURLMcode.scala new file mode 100644 index 0000000..cbff814 --- /dev/null +++ b/curl/src/main/scala/org/http4s/curl/unsafe/CURLMcode.scala @@ -0,0 +1,30 @@ +/* + * Copyright 2022 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.curl.unsafe + +import org.http4s.curl.CurlError + +import scala.scalanative.unsafe._ + +final private[curl] case class CURLMcode(value: CInt) extends AnyVal { + @inline def isOk: Boolean = value == 0 + @inline def isError: Boolean = value != 0 + @inline def throwOnError: Unit = + if (isError) { + throw CurlError.fromMCode(this) + } +} diff --git a/curl/src/main/scala/org/http4s/curl/unsafe/CURLcode.scala b/curl/src/main/scala/org/http4s/curl/unsafe/CURLcode.scala new file mode 100644 index 0000000..b10fa65 --- /dev/null +++ b/curl/src/main/scala/org/http4s/curl/unsafe/CURLcode.scala @@ -0,0 +1,30 @@ +/* + * Copyright 2022 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.curl.unsafe + +import org.http4s.curl.CurlError + +import scala.scalanative.unsafe._ + +final private[curl] case class CURLcode(value: CInt) extends AnyVal { + @inline def isOk: Boolean = value == 0 + @inline def isError: Boolean = value != 0 + @inline def throwOnError: Unit = + if (isError) { + throw CurlError.fromCode(this) + } +} diff --git a/curl/src/main/scala/org/http4s/curl/unsafe/CurlExecutorScheduler.scala b/curl/src/main/scala/org/http4s/curl/unsafe/CurlExecutorScheduler.scala index 27df139..02b4d65 100644 --- a/curl/src/main/scala/org/http4s/curl/unsafe/CurlExecutorScheduler.scala +++ b/curl/src/main/scala/org/http4s/curl/unsafe/CurlExecutorScheduler.scala @@ -21,11 +21,13 @@ import cats.effect.kernel.Resource import cats.effect.unsafe.PollingExecutorScheduler import org.http4s.curl.CurlError +import scala.annotation.nowarn import scala.collection.mutable import scala.concurrent.duration.Duration import scala.scalanative.unsafe._ import scala.scalanative.unsigned._ +@nowarn final private[curl] class CurlExecutorScheduler(multiHandle: Ptr[libcurl.CURLM], pollEvery: Int) extends PollingExecutorScheduler(pollEvery) { diff --git a/curl/src/main/scala/org/http4s/curl/unsafe/CurlMultiSocketImpl.scala b/curl/src/main/scala/org/http4s/curl/unsafe/CurlMultiSocketImpl.scala new file mode 100644 index 0000000..8e9293a --- /dev/null +++ b/curl/src/main/scala/org/http4s/curl/unsafe/CurlMultiSocketImpl.scala @@ -0,0 +1,286 @@ +/* + * Copyright 2022 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.curl.unsafe + +import cats.effect.FiberIO +import cats.effect.FileDescriptorPollHandle +import cats.effect.FileDescriptorPoller +import cats.effect.IO +import cats.effect.kernel.Ref +import cats.effect.kernel.Resource +import cats.effect.std.AtomicCell +import cats.effect.std.Dispatcher +import cats.syntax.all._ +import org.http4s.curl.CurlError +import org.http4s.curl.internal._ + +import scala.concurrent.duration._ +import scala.scalanative.unsafe._ + +private[curl] trait CurlMultiSocket { + def addHandlerTerminating(easy: CurlEasy, cb: Either[Throwable, Unit] => Unit): IO[Unit] + def addHandlerNonTerminating( + easy: CurlEasy, + cb: Either[Throwable, Unit] => Unit, + ): Resource[IO, Unit] +} + +private[curl] object CurlMultiSocket { + implicit private class OptFibOps(private val f: Option[FiberIO[?]]) extends AnyVal { + def cancel: IO[Unit] = f.fold(IO.unit)(_.cancel) + } + + private val getFDPoller = IO.pollers.flatMap( + _.collectFirst { case poller: FileDescriptorPoller => poller }.liftTo[IO]( + new RuntimeException("Installed PollingSystem does not provide a FileDescriptorPoller") + ) + ) + + private val newCurlMutli = Resource.make(IO { + val multiHandle = libcurl.curl_multi_init() + if (multiHandle == null) + throw new RuntimeException("curl_multi_init") + multiHandle + })(mhandle => + IO { + val code = libcurl.curl_multi_cleanup(mhandle) + if (code.isError) + throw CurlError.fromMCode(code) + } + ) + + private lazy val curlGlobalSetup = { + val initCode = libcurl.curl_global_init(2) + if (initCode.isError) + throw CurlError.fromCode(initCode) + } + + def apply(): Resource[IO, CurlMultiSocket] = for { + _ <- IO(curlGlobalSetup).toResource + handle <- newCurlMutli + fdPoller <- getFDPoller.toResource + disp <- Dispatcher.sequential[IO] + mapping <- AtomicCell[IO].of(Map.empty[libcurl.curl_socket_t, Monitoring]).toResource + timeout <- IO.ref[Option[FiberIO[Unit]]](None).toResource + cms = new CurlMultiSocketImpl(handle, fdPoller, mapping, disp, timeout) + _ <- setup(cms, handle).toResource + } yield cms + + private def setup(cms: CurlMultiSocketImpl, handle: Ptr[libcurl.CURLM]) = IO { + val data = Utils.toPtr(cms) + + libcurl + .curl_multi_setopt_timerdata( + handle, + libcurl_const.CURLMOPT_TIMERDATA, + data, + ) + .throwOnError + + libcurl + .curl_multi_setopt_socketdata( + handle, + libcurl_const.CURLMOPT_SOCKETDATA, + data, + ) + .throwOnError + + libcurl + .curl_multi_setopt_timerfunction( + handle, + libcurl_const.CURLMOPT_TIMERFUNCTION, + onTimeout(_, _, _), + ) + .throwOnError + + libcurl + .curl_multi_setopt_socketfunction( + handle, + libcurl_const.CURLMOPT_SOCKETFUNCTION, + onSocket(_, _, _, _, _), + ) + .throwOnError + + } *> cms.notifyTimeout + + final private case class Monitoring( + read: Option[FiberIO[Nothing]], + write: Option[FiberIO[Nothing]], + handle: FileDescriptorPollHandle, + unregister: IO[Unit], + ) { + def clean: IO[Unit] = IO.uncancelable(_ => read.cancel !> write.cancel !> unregister) + } + + private def onTimeout( + mutli: Ptr[libcurl.CURLM], + timeoutMs: CLong, + userdata: Ptr[Byte], + ): CInt = { + val d = Utils.fromPtr[CurlMultiSocketImpl](userdata) + + if (timeoutMs == -1) { + d.removeTimeout + } else { + d.setTimeout(timeoutMs) + } + 0 + } + + private def onSocket( + easy: Ptr[libcurl.CURL], + fd: libcurl.curl_socket_t, + what: Int, + userdata: Ptr[Byte], + socketdata: Ptr[Byte], + ): CInt = { + val d = Utils.fromPtr[CurlMultiSocketImpl](userdata) + + what match { + case libcurl_const.CURL_POLL_IN => d.addFD(fd, true, false) + case libcurl_const.CURL_POLL_OUT => d.addFD(fd, false, true) + case libcurl_const.CURL_POLL_INOUT => d.addFD(fd, true, true) + case libcurl_const.CURL_POLL_REMOVE => d.remove(fd) + case other => throw new UnknownError(s"Received unknown socket request: $other!") + } + + 0 + } + + final private class CurlMultiSocketImpl( + multiHandle: Ptr[libcurl.CURLM], + fdpoller: FileDescriptorPoller, + mapping: AtomicCell[IO, Map[libcurl.curl_socket_t, Monitoring]], + disp: Dispatcher[IO], + timeout: Ref[IO, Option[FiberIO[Unit]]], + ) extends CurlMultiSocket { + + private val callbacks = + scala.collection.mutable.Map[Ptr[libcurl.CURL], Either[Throwable, Unit] => Unit]() + + override def addHandlerTerminating( + easy: CurlEasy, + cb: Either[Throwable, Unit] => Unit, + ): IO[Unit] = IO { + libcurl.curl_multi_add_handle(multiHandle, easy.curl).throwOnError + callbacks(easy.curl) = cb + } + + override def addHandlerNonTerminating( + easy: CurlEasy, + cb: Either[Throwable, Unit] => Unit, + ): Resource[IO, Unit] = + Resource.make(addHandlerTerminating(easy, cb))(_ => + IO { + libcurl.curl_multi_remove_handle(multiHandle, easy.curl).throwOnError + callbacks.remove(easy.curl).foreach(_(Right(()))) + } + ) + + def addFD(fd: libcurl.curl_socket_t, read: Boolean, write: Boolean): Unit = + disp.unsafeRunAndForget { + + val newMonitor = fdpoller.registerFileDescriptor(fd, read, write).allocated.flatMap { + case (handle, unregister) => + ( + Option.when(read)(readLoop(fd, handle)).sequence, + Option.when(write)(writeLoop(fd, handle)).sequence, + ) + .mapN(Monitoring(_, _, handle, unregister)) + } + + IO.uncancelable(_ => + mapping.evalUpdate { m => + m.get(fd) match { + case None => + newMonitor.map(m.updated(fd, _)) + case Some(s: Monitoring) => + s.clean *> newMonitor.map(m.updated(fd, _)) + } + } + ) + } + + def remove(fd: libcurl.curl_socket_t): Unit = + disp.unsafeRunAndForget( + IO.uncancelable(_ => + mapping.evalUpdate { m => + m.get(fd) match { + case None => IO(m) + case Some(s) => s.clean.as(m - fd) + } + } + ) + ) + + def setTimeout(duration: Long): Unit = disp.unsafeRunAndForget( + (IO.sleep(duration.millis) *> notifyTimeout).start.flatMap(f => + timeout.getAndSet(Some(f)).flatMap(_.cancel) + ) + ) + + def removeTimeout: Unit = disp.unsafeRunAndForget( + timeout.getAndSet(None).flatMap(_.cancel) + ) + + def notifyTimeout: IO[Unit] = IO { + val running = stackalloc[Int]() + libcurl + .curl_multi_socket_action(multiHandle, libcurl_const.CURL_SOCKET_TIMEOUT, 0, running) + .throwOnError + + postAction + } + + private def postAction = while ({ + val msgsInQueue = stackalloc[CInt]() + val info = libcurl.curl_multi_info_read(multiHandle, msgsInQueue) + + if (info != null) { + val curMsg = libcurl.curl_CURLMsg_msg(info) + if (curMsg == libcurl_const.CURLMSG_DONE) { + val handle = libcurl.curl_CURLMsg_easy_handle(info) + callbacks.remove(handle).foreach { cb => + val result = libcurl.curl_CURLMsg_data_result(info) + cb( + if (result.isOk) Right(()) + else Left(CurlError.fromCode(result)) + ) + } + + val code = libcurl.curl_multi_remove_handle(multiHandle, handle) + if (code.isError) + throw CurlError.fromMCode(code) + } + true + } else false + }) {} + + private def action(fd: libcurl.curl_socket_t, ev: CInt) = IO { + val running = stackalloc[Int]() + libcurl.curl_multi_socket_action(multiHandle, fd, ev, running) + + postAction + + Left(()) + } + private def readLoop(fd: libcurl.curl_socket_t, p: FileDescriptorPollHandle) = + p.pollReadRec(())(_ => action(fd, libcurl_const.CURL_CSELECT_IN)).start + private def writeLoop(fd: libcurl.curl_socket_t, p: FileDescriptorPollHandle) = + p.pollWriteRec(())(_ => action(fd, libcurl_const.CURL_CSELECT_OUT)).start + } +} diff --git a/curl/src/main/scala/org/http4s/curl/unsafe/libcurl.scala b/curl/src/main/scala/org/http4s/curl/unsafe/libcurl.scala index fb9a698..976b5ea 100644 --- a/curl/src/main/scala/org/http4s/curl/unsafe/libcurl.scala +++ b/curl/src/main/scala/org/http4s/curl/unsafe/libcurl.scala @@ -128,6 +128,9 @@ private[curl] object libcurl_const { // websocket options flags final val CURLWS_RAW_MODE = 1 << 0 + final val CURL_SOCKET_BAD = -1 + final val CURL_SOCKET_TIMEOUT = CURL_SOCKET_BAD + final val CURL_CSELECT_IN = 0x01 final val CURL_CSELECT_OUT = 0x02 final val CURL_CSELECT_ERR = 0x04 @@ -139,15 +142,6 @@ private[curl] object libcurl_const { final val CURL_POLL_REMOVE = 4 } -final private[curl] case class CURLcode(value: CInt) extends AnyVal { - @inline def isOk: Boolean = value == 0 - @inline def isError: Boolean = value != 0 -} -final private[curl] case class CURLMcode(value: CInt) extends AnyVal { - @inline def isOk: Boolean = value == 0 - @inline def isError: Boolean = value != 0 -} - @link("curl") @extern private[curl] object libcurl { @@ -369,7 +363,7 @@ private[curl] object libcurl { @name("curl_multi_setopt") def curl_multi_setopt_timerdata( curl: Ptr[CURLM], - option: CURLMOPT_TIMERFUNCTION.type, + option: CURLMOPT_TIMERDATA.type, pointer: Ptr[Byte], ): CURLMcode = extern diff --git a/curl/src/main/scala/org/http4s/curl/websocket/Connection.scala b/curl/src/main/scala/org/http4s/curl/websocket/Connection.scala index bd9a5d7..692e8b8 100644 --- a/curl/src/main/scala/org/http4s/curl/websocket/Connection.scala +++ b/curl/src/main/scala/org/http4s/curl/websocket/Connection.scala @@ -29,6 +29,7 @@ import org.http4s.client.websocket._ import org.http4s.curl.internal.Utils import org.http4s.curl.internal._ import org.http4s.curl.unsafe.CurlExecutorScheduler +import org.http4s.curl.unsafe.CurlMultiSocket import org.http4s.curl.unsafe.libcurl import org.http4s.curl.unsafe.libcurl_const import scodec.bits.ByteVector @@ -242,6 +243,42 @@ private object Connection { _ <- estab.get.flatMap(IO.fromEither).toResource } yield con + def apply( + req: WSRequest, + ms: CurlMultiSocket, + recvBufferSize: Int, + pauseOn: Int, + resumeOn: Int, + verbose: Boolean, + ): Resource[IO, Connection] = for { + gc <- GCRoot() + dispatcher <- Dispatcher.sequential[IO] + recvQ <- Queue.bounded[IO, Option[WSFrame]](recvBufferSize).toResource + recv <- Ref[SyncIO].of(Option.empty[Receiving]).to[IO].toResource + estab <- IO.deferred[Either[Throwable, Unit]].toResource + handler <- CurlEasy() + brk <- Breaker( + handler, + capacity = recvBufferSize, + close = resumeOn, + open = pauseOn, + verbose, + ).toResource + con = new Connection( + handler, + recvQ, + recv, + dispatcher, + estab, + brk, + ) + _ <- setup(req, verbose)(con) + _ <- gc.add(con) + _ <- ms.addHandlerNonTerminating(handler, con.onTerminated) + // Wait until established or throw error + _ <- estab.get.flatMap(IO.fromEither).toResource + } yield con + } sealed private trait ReceivingType extends Serializable with Product diff --git a/curl/src/main/scala/org/http4s/curl/websocket/CurlWSClient.scala b/curl/src/main/scala/org/http4s/curl/websocket/CurlWSClient.scala index 276d977..b21089d 100644 --- a/curl/src/main/scala/org/http4s/curl/websocket/CurlWSClient.scala +++ b/curl/src/main/scala/org/http4s/curl/websocket/CurlWSClient.scala @@ -22,6 +22,7 @@ import cats.implicits._ import org.http4s.client.websocket.WSFrame._ import org.http4s.client.websocket._ import org.http4s.curl.unsafe.CurlExecutorScheduler +import org.http4s.curl.unsafe.CurlMultiSocket import org.http4s.curl.unsafe.CurlRuntime import org.http4s.curl.unsafe.libcurl_const import scodec.bits.ByteVector @@ -89,4 +90,52 @@ private[curl] object CurlWSClient { ) } } + + def apply( + ms: CurlMultiSocket, + recvBufferSize: Int = 100, + pauseOn: Int = 10, + resumeOn: Int = 30, + verbose: Boolean = false, + ): Option[WSClient[IO]] = + Option.when(CurlRuntime.isWebsocketAvailable && CurlRuntime.curlVersionNumber >= 0x75700) { + WSClient(true) { req => + Connection(req, ms, recvBufferSize, pauseOn, resumeOn, verbose) + .map(con => + new WSConnection[IO] { + override def send(wsf: WSFrame): IO[Unit] = wsf match { + case Close(_, _) => + val flags = libcurl_const.CURLWS_CLOSE + con.send(flags, ByteVector.empty) + case Ping(data) => + val flags = libcurl_const.CURLWS_PING + con.send(flags, data) + case Pong(data) => + val flags = libcurl_const.CURLWS_PONG + con.send(flags, data) + case Text(data, true) => + val flags = libcurl_const.CURLWS_TEXT + val bv = + ByteVector.encodeUtf8(data).getOrElse(throw InvalidTextFrame) + con.send(flags, bv) + case Binary(data, true) => + val flags = libcurl_const.CURLWS_BINARY + con.send(flags, data) + case _ => + // NOTE curl needs to know total amount of fragment size in first send + // and it is not compatible with current websocket interface in http4s + IO.raiseError(PartialFragmentFrame) + } + + override def sendMany[G[_]: Foldable, A <: WSFrame](wsfs: G[A]): IO[Unit] = + wsfs.traverse_(send) + + override def receive: IO[Option[WSFrame]] = con.receive + + override def subprotocol: Option[String] = None + + } + ) + } + } } diff --git a/tests/http/src/test/scala/CurlClientMultiSocketSuite.scala b/tests/http/src/test/scala/CurlClientMultiSocketSuite.scala new file mode 100644 index 0000000..39f95ea --- /dev/null +++ b/tests/http/src/test/scala/CurlClientMultiSocketSuite.scala @@ -0,0 +1,85 @@ +/* + * Copyright 2022 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.curl + +import cats.effect.IO +import cats.effect.SyncIO +import cats.effect.std.Random +import cats.syntax.all._ +import munit.CatsEffectSuite +import org.http4s.Method._ +import org.http4s.Request +import org.http4s.Status +import org.http4s.client.Client +import org.http4s.curl.unsafe.CurlMultiSocket +import org.http4s.syntax.all._ + +class CurlClientMultiSocketSuite extends CatsEffectSuite { + + val clientFixture: SyncIO[FunFixture[Client[IO]]] = ResourceFunFixture( + CurlMultiSocket().map(http.CurlClient.multiSocket(_)) + ) + + clientFixture.test("3 get echos") { client => + client + .expect[String]("http://localhost:8080/http") + .map(_.nonEmpty) + .assert + .parReplicateA_(3) + } + + clientFixture.test("500") { client => + client + .statusFromString("http://localhost:8080/http/500") + .assertEquals(Status.InternalServerError) + } + + clientFixture.test("unexpected") { client => + client + .expect[String]("http://localhost:8080/http/500") + .attempt + .map(_.isLeft) + .assert + } + + clientFixture.test("error") { client => + client.expect[String]("unsupported://server").intercept[CurlError] + } + + clientFixture.test("error") { client => + client.expect[String]("").intercept[CurlError] + } + + clientFixture.test("3 post echos") { client => + Random.scalaUtilRandom[IO].flatMap { random => + random + .nextString(8) + .flatMap { s => + val msg = s"hello postman $s" + client + .expect[String]( + Request[IO](POST, uri = uri"http://localhost:8080/http/echo").withEntity(msg) + ) + .flatTap(IO.println) + .map(_.contains(msg)) + .assert + } + .parReplicateA_(3) + } + } + +} diff --git a/tests/websocket/src/test/scala/CurlWSClientMultiSocketSuite.scala b/tests/websocket/src/test/scala/CurlWSClientMultiSocketSuite.scala new file mode 100644 index 0000000..8997f35 --- /dev/null +++ b/tests/websocket/src/test/scala/CurlWSClientMultiSocketSuite.scala @@ -0,0 +1,90 @@ +/* + * Copyright 2022 http4s.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.http4s.curl + +import cats.effect.IO +import cats.syntax.all._ +import munit.CatsEffectSuite +import org.http4s.client.websocket.WSFrame +import org.http4s.client.websocket.WSRequest +import org.http4s.curl.unsafe.CurlMultiSocket +import org.http4s.curl.websocket.CurlWSClient +import org.http4s.implicits._ + +class CurlWSClientMultiSocketSuite extends CatsEffectSuite { + + private val clientFixture = ResourceFunFixture( + CurlMultiSocket().evalMap( + CurlWSClient(_).liftTo[IO]( + new RuntimeException("websocket client is not supported in this environment") + ) + ) + ) + + clientFixture.test("websocket echo") { + val frames = List.range(1, 5).map(i => WSFrame.Text(s"text $i")) + + _.connectHighLevel(WSRequest(uri"ws://localhost:8080/ws/echo")) + .use(con => + con.receiveStream + .take(4) + .evalTap(IO.println) + .compile + .toList <& (frames.traverse(con.send(_))) + ) + .assertEquals(frames) + } + + clientFixture.test("websocket bounded") { + _.connectHighLevel(WSRequest(uri"ws://localhost:8080/ws/bounded")) + .use(con => + con.receiveStream + .evalTap(IO.println) + .compile + .toList + ) + .assertEquals(List(WSFrame.Text("everything"))) + } + + clientFixture.test("websocket closed") { + _.connectHighLevel(WSRequest(uri"ws://localhost:8080/ws/closed")) + .use(con => con.receiveStream.compile.toList) + .assertEquals(Nil) + .parReplicateA_(4) + } + + clientFixture.test("error") { client => + client + .connectHighLevel(WSRequest(uri"")) + .use_ + .intercept[CurlError] + } + + clientFixture.test("error") { client => + client + .connectHighLevel(WSRequest(uri"server")) + .use_ + .intercept[CurlError] + } + + clientFixture.test("invalid protocol") { client => + client + .connectHighLevel(WSRequest(uri"http://localhost:8080/http")) + .use_ + .intercept[IllegalArgumentException] + } +}