diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc index 8fb2dbcc5..13cfc88ec 100644 --- a/libappfuse/FuseBuffer.cc +++ b/libappfuse/FuseBuffer.cc @@ -23,77 +23,132 @@ #include #include +#include + #include #include #include namespace android { namespace fuse { - -static_assert( - std::is_standard_layout::value, - "FuseBuffer must be standard layout union."); +namespace { template -bool FuseMessage::CheckHeaderLength(const char* name) const { - const auto& header = static_cast(this)->header; - if (header.len >= sizeof(header) && header.len <= sizeof(T)) { +bool CheckHeaderLength(const FuseMessage* self, const char* name) { + const auto& header = static_cast(self)->header; + if (header.len >= sizeof(header) && header.len <= sizeof(T)) { + return true; + } else { + LOG(ERROR) << "Invalid header length is found in " << name << ": " << header.len; + return false; + } +} + +template +ResultOrAgain ReadInternal(FuseMessage* self, int fd, int sockflag) { + char* const buf = reinterpret_cast(self); + const ssize_t result = sockflag ? TEMP_FAILURE_RETRY(recv(fd, buf, sizeof(T), sockflag)) + : TEMP_FAILURE_RETRY(read(fd, buf, sizeof(T))); + + switch (result) { + case 0: + // Expected EOF. + return ResultOrAgain::kFailure; + case -1: + if (errno == EAGAIN) { + return ResultOrAgain::kAgain; + } + PLOG(ERROR) << "Failed to read a FUSE message"; + return ResultOrAgain::kFailure; + } + + const auto& header = static_cast(self)->header; + if (result < static_cast(sizeof(header))) { + LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << sizeof(header); + return ResultOrAgain::kFailure; + } + + if (!CheckHeaderLength(self, "Read")) { + return ResultOrAgain::kFailure; + } + + if (static_cast(result) != header.len) { + LOG(ERROR) << "Read bytes " << result << " are different from header.len " << header.len; + return ResultOrAgain::kFailure; + } + + return ResultOrAgain::kSuccess; +} + +template +ResultOrAgain WriteInternal(const FuseMessage* self, int fd, int sockflag) { + if (!CheckHeaderLength(self, "Write")) { + return ResultOrAgain::kFailure; + } + + const char* const buf = reinterpret_cast(self); + const auto& header = static_cast(self)->header; + const int result = sockflag ? TEMP_FAILURE_RETRY(send(fd, buf, header.len, sockflag)) + : TEMP_FAILURE_RETRY(write(fd, buf, header.len)); + + if (result == -1) { + if (errno == EAGAIN) { + return ResultOrAgain::kAgain; + } + PLOG(ERROR) << "Failed to write a FUSE message"; + return ResultOrAgain::kFailure; + } + + CHECK(static_cast(result) == header.len); + return ResultOrAgain::kSuccess; +} +} + +static_assert(std::is_standard_layout::value, + "FuseBuffer must be standard layout union."); + +bool SetupMessageSockets(base::unique_fd (*result)[2]) { + base::unique_fd fds[2]; + { + int raw_fds[2]; + if (socketpair(AF_UNIX, SOCK_SEQPACKET, 0, raw_fds) == -1) { + PLOG(ERROR) << "Failed to create sockets for proxy"; + return false; + } + fds[0].reset(raw_fds[0]); + fds[1].reset(raw_fds[1]); + } + + constexpr int kMaxMessageSize = sizeof(FuseBuffer); + if (setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0 || + setsockopt(fds[1], SOL_SOCKET, SO_SNDBUF, &kMaxMessageSize, sizeof(int)) != 0) { + PLOG(ERROR) << "Failed to update buffer size for socket"; + return false; + } + + (*result)[0] = std::move(fds[0]); + (*result)[1] = std::move(fds[1]); return true; - } else { - LOG(ERROR) << "Invalid header length is found in " << name << ": " << - header.len; - return false; - } } template bool FuseMessage::Read(int fd) { - char* const buf = reinterpret_cast(this); - const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, buf, sizeof(T))); - if (result < 0) { - PLOG(ERROR) << "Failed to read a FUSE message"; - return false; - } + return ReadInternal(this, fd, 0) == ResultOrAgain::kSuccess; +} - const auto& header = static_cast(this)->header; - if (result < static_cast(sizeof(header))) { - LOG(ERROR) << "Read bytes " << result << " are shorter than header size " << - sizeof(header); - return false; - } - - if (!CheckHeaderLength("Read")) { - return false; - } - - if (static_cast(result) > header.len) { - LOG(ERROR) << "Read bytes " << result << " are longer than header.len " << - header.len; - return false; - } - - if (!base::ReadFully(fd, buf + result, header.len - result)) { - PLOG(ERROR) << "ReadFully failed"; - return false; - } - - return true; +template +ResultOrAgain FuseMessage::ReadOrAgain(int fd) { + return ReadInternal(this, fd, MSG_DONTWAIT); } template bool FuseMessage::Write(int fd) const { - if (!CheckHeaderLength("Write")) { - return false; - } + return WriteInternal(this, fd, 0) == ResultOrAgain::kSuccess; +} - const char* const buf = reinterpret_cast(this); - const auto& header = static_cast(this)->header; - if (!base::WriteFully(fd, buf, header.len)) { - PLOG(ERROR) << "WriteFully failed"; - return false; - } - - return true; +template +ResultOrAgain FuseMessage::WriteOrAgain(int fd) const { + return WriteInternal(this, fd, MSG_DONTWAIT); } template class FuseMessage; diff --git a/libappfuse/include/libappfuse/FuseBuffer.h b/libappfuse/include/libappfuse/FuseBuffer.h index 7abd2fa40..fbb05d633 100644 --- a/libappfuse/include/libappfuse/FuseBuffer.h +++ b/libappfuse/include/libappfuse/FuseBuffer.h @@ -17,6 +17,7 @@ #ifndef ANDROID_LIBAPPFUSE_FUSEBUFFER_H_ #define ANDROID_LIBAPPFUSE_FUSEBUFFER_H_ +#include #include namespace android { @@ -28,12 +29,24 @@ constexpr size_t kFuseMaxWrite = 256 * 1024; constexpr size_t kFuseMaxRead = 128 * 1024; constexpr int32_t kFuseSuccess = 0; +// Setup sockets to transfer FuseMessage. +bool SetupMessageSockets(base::unique_fd (*sockets)[2]); + +enum class ResultOrAgain { + kSuccess, + kFailure, + kAgain, +}; + template class FuseMessage { public: bool Read(int fd); bool Write(int fd) const; - private: + ResultOrAgain ReadOrAgain(int fd); + ResultOrAgain WriteOrAgain(int fd) const; + +private: bool CheckHeaderLength(const char* name) const; }; @@ -54,7 +67,7 @@ struct FuseRequest : public FuseMessage { // for FUSE_READ fuse_read_in read_in; // for FUSE_LOOKUP - char lookup_name[0]; + char lookup_name[kFuseMaxWrite]; }; void Reset(uint32_t data_length, uint32_t opcode, uint64_t unique); }; diff --git a/libappfuse/tests/FuseAppLoopTest.cc b/libappfuse/tests/FuseAppLoopTest.cc index 25906cf1c..64dd81330 100644 --- a/libappfuse/tests/FuseAppLoopTest.cc +++ b/libappfuse/tests/FuseAppLoopTest.cc @@ -109,10 +109,7 @@ class FuseAppLoopTest : public ::testing::Test { void SetUp() override { base::SetMinimumLogSeverity(base::VERBOSE); - int sockets[2]; - ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, sockets)); - sockets_[0].reset(sockets[0]); - sockets_[1].reset(sockets[1]); + ASSERT_TRUE(SetupMessageSockets(&sockets_)); thread_ = std::thread([this] { StartFuseAppLoop(sockets_[1].release(), &callback_); }); diff --git a/libappfuse/tests/FuseBridgeLoopTest.cc b/libappfuse/tests/FuseBridgeLoopTest.cc index e74d9e700..b4c1efb01 100644 --- a/libappfuse/tests/FuseBridgeLoopTest.cc +++ b/libappfuse/tests/FuseBridgeLoopTest.cc @@ -50,15 +50,8 @@ class FuseBridgeLoopTest : public ::testing::Test { void SetUp() override { base::SetMinimumLogSeverity(base::VERBOSE); - int dev_sockets[2]; - int proxy_sockets[2]; - ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, dev_sockets)); - ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_SEQPACKET, 0, proxy_sockets)); - dev_sockets_[0].reset(dev_sockets[0]); - dev_sockets_[1].reset(dev_sockets[1]); - proxy_sockets_[0].reset(proxy_sockets[0]); - proxy_sockets_[1].reset(proxy_sockets[1]); - + ASSERT_TRUE(SetupMessageSockets(&dev_sockets_)); + ASSERT_TRUE(SetupMessageSockets(&proxy_sockets_)); thread_ = std::thread([this] { StartFuseBridgeLoop( dev_sockets_[1].release(), proxy_sockets_[0].release(), &callback_); diff --git a/libappfuse/tests/FuseBufferTest.cc b/libappfuse/tests/FuseBufferTest.cc index 1a1abd57e..ade34acc1 100644 --- a/libappfuse/tests/FuseBufferTest.cc +++ b/libappfuse/tests/FuseBufferTest.cc @@ -112,30 +112,6 @@ TEST(FuseMessageTest, Write_TooShort) { TestWriteInvalidLength(sizeof(fuse_in_header) - 1); } -TEST(FuseMessageTest, ShortWriteAndRead) { - int raw_fds[2]; - ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, raw_fds)); - - android::base::unique_fd fds[2]; - fds[0].reset(raw_fds[0]); - fds[1].reset(raw_fds[1]); - - const int send_buffer_size = 1024; - ASSERT_EQ(0, setsockopt(fds[0], SOL_SOCKET, SO_SNDBUF, &send_buffer_size, - sizeof(int))); - - bool succeed = false; - const int sender_fd = fds[0].get(); - std::thread thread([sender_fd, &succeed] { - FuseRequest request; - request.header.len = 1024 * 4; - succeed = request.Write(sender_fd); - }); - thread.detach(); - FuseRequest request; - ASSERT_TRUE(request.Read(fds[1])); -} - TEST(FuseResponseTest, Reset) { FuseResponse response; // Write 1 to the first ten bytes. @@ -211,5 +187,29 @@ TEST(FuseBufferTest, HandleNotImpl) { EXPECT_EQ(-ENOSYS, buffer.response.header.error); } +TEST(SetupMessageSocketsTest, Stress) { + constexpr int kCount = 1000; + + FuseRequest request; + request.header.len = sizeof(FuseRequest); + + base::unique_fd fds[2]; + SetupMessageSockets(&fds); + + std::thread thread([&fds] { + FuseRequest request; + for (int i = 0; i < kCount; ++i) { + ASSERT_TRUE(request.Read(fds[1])); + usleep(1000); + } + }); + + for (int i = 0; i < kCount; ++i) { + ASSERT_TRUE(request.Write(fds[0])); + } + + thread.join(); +} + } // namespace fuse } // namespace android