Loading fastboot/.clang-format +4 −1 Original line number Diff line number Diff line BasedOnStyle: Google AllowShortBlocksOnASingleLine: false AllowShortFunctionsOnASingleLine: false AllowShortFunctionsOnASingleLine: Inline ColumnLimit: 100 CommentPragmas: NOLINT:.* DerivePointerAlignment: false IndentWidth: 4 ContinuationIndentWidth: 8 ConstructorInitializerIndentWidth: 8 AccessModifierOffset: -2 PointerAlignment: Left TabWidth: 4 UseTab: Never Loading fastboot/Android.mk +13 −9 Original line number Diff line number Diff line Loading @@ -24,7 +24,15 @@ LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/../../extras/ext4_utils \ $(LOCAL_PATH)/../../extras/f2fs_utils \ LOCAL_SRC_FILES := protocol.cpp engine.cpp bootimg_utils.cpp fastboot.cpp util.cpp fs.cpp LOCAL_SRC_FILES := \ bootimg_utils.cpp \ engine.cpp \ fastboot.cpp \ fs.cpp\ protocol.cpp \ socket.cpp \ util.cpp \ LOCAL_MODULE := fastboot LOCAL_MODULE_TAGS := debug LOCAL_MODULE_HOST_OS := darwin linux windows Loading @@ -33,15 +41,15 @@ LOCAL_CFLAGS += -Wall -Wextra -Werror -Wunreachable-code LOCAL_CFLAGS += -DFASTBOOT_REVISION='"$(fastboot_version)"' LOCAL_SRC_FILES_linux := socket_unix.cpp usb_linux.cpp util_linux.cpp LOCAL_SRC_FILES_linux := usb_linux.cpp util_linux.cpp LOCAL_STATIC_LIBRARIES_linux := libselinux LOCAL_SRC_FILES_darwin := socket_unix.cpp usb_osx.cpp util_osx.cpp LOCAL_SRC_FILES_darwin := usb_osx.cpp util_osx.cpp LOCAL_STATIC_LIBRARIES_darwin := libselinux LOCAL_LDLIBS_darwin := -lpthread -framework CoreFoundation -framework IOKit -framework Carbon LOCAL_CFLAGS_darwin := -Wno-unused-parameter LOCAL_SRC_FILES_windows := socket_windows.cpp usb_windows.cpp util_windows.cpp LOCAL_SRC_FILES_windows := usb_windows.cpp util_windows.cpp LOCAL_STATIC_LIBRARIES_windows := AdbWinApi LOCAL_REQUIRED_MODULES_windows := AdbWinApi LOCAL_LDLIBS_windows := -lws2_32 Loading Loading @@ -98,18 +106,14 @@ include $(CLEAR_VARS) LOCAL_MODULE := fastboot_test LOCAL_MODULE_HOST_OS := darwin linux windows LOCAL_SRC_FILES := socket_test.cpp LOCAL_SRC_FILES := socket.cpp socket_test.cpp LOCAL_STATIC_LIBRARIES := libbase libcutils LOCAL_CFLAGS += -Wall -Wextra -Werror -Wunreachable-code LOCAL_SRC_FILES_linux := socket_unix.cpp LOCAL_SRC_FILES_darwin := socket_unix.cpp LOCAL_LDLIBS_darwin := -lpthread -framework CoreFoundation -framework IOKit -framework Carbon LOCAL_CFLAGS_darwin := -Wno-unused-parameter LOCAL_SRC_FILES_windows := socket_windows.cpp LOCAL_LDLIBS_windows := -lws2_32 include $(BUILD_HOST_NATIVE_TEST) fastboot/socket_unix.cpp→fastboot/socket.cpp +212 −0 Original line number Diff line number Diff line Loading @@ -28,34 +28,74 @@ #include "socket.h" #include <errno.h> #include <netdb.h> #include <android-base/stringprintf.h> #include <cutils/sockets.h> class UnixUdpSocket : public UdpSocket { Socket::Socket(cutils_socket_t sock) : sock_(sock) {} Socket::~Socket() { Close(); } int Socket::Close() { int ret = 0; if (sock_ != INVALID_SOCKET) { ret = socket_close(sock_); sock_ = INVALID_SOCKET; } return ret; } bool Socket::SetReceiveTimeout(int timeout_ms) { if (timeout_ms != receive_timeout_ms_) { if (socket_set_receive_timeout(sock_, timeout_ms) == 0) { receive_timeout_ms_ = timeout_ms; return true; } return false; } return true; } ssize_t Socket::ReceiveAll(void* data, size_t length, int timeout_ms) { size_t total = 0; while (total < length) { ssize_t bytes = Receive(reinterpret_cast<char*>(data) + total, length - total, timeout_ms); if (bytes == -1) { if (total == 0) { return -1; } break; } total += bytes; } return total; } // Implements the Socket interface for UDP. class UdpSocket : public Socket { public: enum class Type { kClient, kServer }; UnixUdpSocket(int fd, Type type); ~UnixUdpSocket() override; UdpSocket(Type type, cutils_socket_t sock); ssize_t Send(const void* data, size_t length) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; int Close() override; private: int fd_; int receive_timeout_ms_ = 0; std::unique_ptr<sockaddr_storage> addr_; socklen_t addr_size_ = 0; DISALLOW_COPY_AND_ASSIGN(UnixUdpSocket); DISALLOW_COPY_AND_ASSIGN(UdpSocket); }; UnixUdpSocket::UnixUdpSocket(int fd, Type type) : fd_(fd) { // Only servers need to remember addresses; clients are connected to a server in NewUdpClient() UdpSocket::UdpSocket(Type type, cutils_socket_t sock) : Socket(sock) { // Only servers need to remember addresses; clients are connected to a server in NewClient() // so will send to that server without needing to specify the address again. if (type == Type::kServer) { addr_.reset(new sockaddr_storage); Loading @@ -64,26 +104,15 @@ UnixUdpSocket::UnixUdpSocket(int fd, Type type) : fd_(fd) { } } UnixUdpSocket::~UnixUdpSocket() { Close(); } ssize_t UnixUdpSocket::Send(const void* data, size_t length) { return TEMP_FAILURE_RETRY( sendto(fd_, data, length, 0, reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)); ssize_t UdpSocket::Send(const void* data, size_t length) { return TEMP_FAILURE_RETRY(sendto(sock_, reinterpret_cast<const char*>(data), length, 0, reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)); } ssize_t UnixUdpSocket::Receive(void* data, size_t length, int timeout_ms) { // Only set socket timeout if it's changed. if (receive_timeout_ms_ != timeout_ms) { timeval tv; tv.tv_sec = timeout_ms / 1000; tv.tv_usec = (timeout_ms % 1000) * 1000; if (setsockopt(fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) { ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) { if (!SetReceiveTimeout(timeout_ms)) { return -1; } receive_timeout_ms_ = timeout_ms; } socklen_t* addr_size_ptr = nullptr; if (addr_ != nullptr) { Loading @@ -91,41 +120,93 @@ ssize_t UnixUdpSocket::Receive(void* data, size_t length, int timeout_ms) { addr_size_ = sizeof(*addr_); addr_size_ptr = &addr_size_; } return TEMP_FAILURE_RETRY(recvfrom(fd_, data, length, 0, return TEMP_FAILURE_RETRY(recvfrom(sock_, reinterpret_cast<char*>(data), length, 0, reinterpret_cast<sockaddr*>(addr_.get()), addr_size_ptr)); } int UnixUdpSocket::Close() { int result = 0; if (fd_ != -1) { result = close(fd_); fd_ = -1; // Implements the Socket interface for TCP. class TcpSocket : public Socket { public: TcpSocket(cutils_socket_t sock) : Socket(sock) {} ssize_t Send(const void* data, size_t length) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; std::unique_ptr<Socket> Accept() override; private: DISALLOW_COPY_AND_ASSIGN(TcpSocket); }; ssize_t TcpSocket::Send(const void* data, size_t length) { size_t total = 0; while (total < length) { ssize_t bytes = TEMP_FAILURE_RETRY( send(sock_, reinterpret_cast<const char*>(data) + total, length - total, 0)); if (bytes == -1) { if (total == 0) { return -1; } break; } return result; total += bytes; } std::unique_ptr<UdpSocket> UdpSocket::NewUdpClient(const std::string& host, int port, std::string* error) { int getaddrinfo_error = 0; int fd = socket_network_client_timeout(host.c_str(), port, SOCK_DGRAM, 0, &getaddrinfo_error); if (fd == -1) { if (error) { *error = android::base::StringPrintf( "Failed to connect to %s:%d: %s", host.c_str(), port, getaddrinfo_error ? gai_strerror(getaddrinfo_error) : strerror(errno)); return total; } ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) { if (!SetReceiveTimeout(timeout_ms)) { return -1; } return TEMP_FAILURE_RETRY(recv(sock_, reinterpret_cast<char*>(data), length, 0)); } std::unique_ptr<Socket> TcpSocket::Accept() { cutils_socket_t handler = accept(sock_, nullptr, nullptr); if (handler == INVALID_SOCKET) { return nullptr; } return std::unique_ptr<TcpSocket>(new TcpSocket(handler)); } return std::unique_ptr<UdpSocket>(new UnixUdpSocket(fd, UnixUdpSocket::Type::kClient)); std::unique_ptr<Socket> Socket::NewClient(Protocol protocol, const std::string& host, int port, std::string* error) { if (protocol == Protocol::kUdp) { cutils_socket_t sock = socket_network_client(host.c_str(), port, SOCK_DGRAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<UdpSocket>(new UdpSocket(UdpSocket::Type::kClient, sock)); } } else { cutils_socket_t sock = socket_network_client(host.c_str(), port, SOCK_STREAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<TcpSocket>(new TcpSocket(sock)); } } std::unique_ptr<UdpSocket> UdpSocket::NewUdpServer(int port) { int fd = socket_inaddr_any_server(port, SOCK_DGRAM); if (fd == -1) { // This is just used in testing, no need for an error message. if (error) { *error = android::base::StringPrintf("Failed to connect to %s:%d", host.c_str(), port); } return nullptr; } return std::unique_ptr<UdpSocket>(new UnixUdpSocket(fd, UnixUdpSocket::Type::kServer)); // This functionality is currently only used by tests so we don't need any error messages. std::unique_ptr<Socket> Socket::NewServer(Protocol protocol, int port) { if (protocol == Protocol::kUdp) { cutils_socket_t sock = socket_inaddr_any_server(port, SOCK_DGRAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<UdpSocket>(new UdpSocket(UdpSocket::Type::kServer, sock)); } } else { cutils_socket_t sock = socket_inaddr_any_server(port, SOCK_STREAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<TcpSocket>(new TcpSocket(sock)); } } return nullptr; } fastboot/socket.h +35 −15 Original line number Diff line number Diff line Loading @@ -26,36 +26,41 @@ * SUCH DAMAGE. */ // This file provides a class interface for cross-platform UDP functionality. The main fastboot // This file provides a class interface for cross-platform socket functionality. The main fastboot // engine should not be using this interface directly, but instead should use a higher-level // interface that enforces the fastboot UDP protocol. // interface that enforces the fastboot protocol. #ifndef SOCKET_H_ #define SOCKET_H_ #include "android-base/macros.h" #include <memory> #include <string> // UdpSocket interface to be implemented for each platform. class UdpSocket { #include <android-base/macros.h> #include <cutils/sockets.h> // Socket interface to be implemented for each platform. class Socket { public: enum class Protocol { kTcp, kUdp }; // Creates a new client connection. Clients are connected to a specific hostname/port and can // only send to that destination. // On failure, |error| is filled (if non-null) and nullptr is returned. static std::unique_ptr<UdpSocket> NewUdpClient(const std::string& hostname, int port, std::string* error); static std::unique_ptr<Socket> NewClient(Protocol protocol, const std::string& hostname, int port, std::string* error); // Creates a new server bound to local |port|. This is only meant for testing, during normal // fastboot operation the device acts as the server. // The server saves sender addresses in Receive(), and uses the most recent address during // A UDP server saves sender addresses in Receive(), and uses the most recent address during // calls to Send(). static std::unique_ptr<UdpSocket> NewUdpServer(int port); static std::unique_ptr<Socket> NewServer(Protocol protocol, int port); virtual ~UdpSocket() = default; // Destructor closes the socket if it's open. virtual ~Socket(); // Sends |length| bytes of |data|. Returns the number of bytes actually sent or -1 on error. // Sends |length| bytes of |data|. For TCP sockets this will continue trying to send until all // bytes are transmitted. Returns the number of bytes actually sent or -1 on error. virtual ssize_t Send(const void* data, size_t length) = 0; // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will Loading @@ -63,14 +68,29 @@ class UdpSocket { // errno will be set to EAGAIN or EWOULDBLOCK. virtual ssize_t Receive(void* data, size_t length, int timeout_ms) = 0; // Calls Receive() until exactly |length| bytes have been received or an error occurs. virtual ssize_t ReceiveAll(void* data, size_t length, int timeout_ms); // Closes the socket. Returns 0 on success, -1 on error. virtual int Close() = 0; virtual int Close(); // Accepts an incoming TCP connection. No effect for UDP sockets. Returns a new Socket // connected to the client on success, nullptr on failure. virtual std::unique_ptr<Socket> Accept() { return nullptr; } protected: // Protected constructor to force factory function use. UdpSocket() = default; Socket(cutils_socket_t sock); // Update the socket receive timeout if necessary. bool SetReceiveTimeout(int timeout_ms); cutils_socket_t sock_ = INVALID_SOCKET; private: int receive_timeout_ms_ = 0; DISALLOW_COPY_AND_ASSIGN(UdpSocket); DISALLOW_COPY_AND_ASSIGN(Socket); }; #endif // SOCKET_H_ fastboot/socket_test.cpp +71 −142 Original line number Diff line number Diff line Loading @@ -14,96 +14,86 @@ * limitations under the License. */ // Tests UDP functionality using loopback connections. Requires that kDefaultPort is available // Tests UDP functionality using loopback connections. Requires that kTestPort is available // for loopback communication on the host. These tests also assume that no UDP packets are lost, // which should be the case for loopback communication, but is not guaranteed. #include "socket.h" #include <errno.h> #include <time.h> #include <memory> #include <string> #include <vector> #include <gtest/gtest.h> enum { // This port must be available for loopback communication. kDefaultPort = 54321, kTestPort = 54321, // Don't wait forever in a unit test. kDefaultTimeoutMs = 3000, kTestTimeoutMs = 3000, }; static const char kReceiveStringError[] = "Error receiving string"; // Test fixture to provide some helper functions. Makes each test a little simpler since we can // just check a bool for socket creation and don't have to pass hostname or port information. class SocketTest : public ::testing::Test { protected: bool StartServer(int port = kDefaultPort) { server_ = UdpSocket::NewUdpServer(port); return server_ != nullptr; // Creates connected sockets |server| and |client|. Returns true on success. bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr<Socket>* server, std::unique_ptr<Socket>* client, const std::string hostname = "localhost", int port = kTestPort) { *server = Socket::NewServer(protocol, port); if (*server == nullptr) { ADD_FAILURE() << "Failed to create server."; return false; } bool StartClient(const std::string hostname = "localhost", int port = kDefaultPort) { client_ = UdpSocket::NewUdpClient(hostname, port, nullptr); return client_ != nullptr; *client = Socket::NewClient(protocol, hostname, port, nullptr); if (*client == nullptr) { ADD_FAILURE() << "Failed to create client."; return false; } bool StartClient2(const std::string hostname = "localhost", int port = kDefaultPort) { client2_ = UdpSocket::NewUdpClient(hostname, port, nullptr); return client2_ != nullptr; // TCP passes the client off to a new socket. if (protocol == Socket::Protocol::kTcp) { *server = (*server)->Accept(); if (*server == nullptr) { ADD_FAILURE() << "Failed to accept client connection."; return false; } std::unique_ptr<UdpSocket> server_, client_, client2_; }; // Sends a string over a UdpSocket. Returns true if the full string (without terminating char) // was sent. static bool SendString(UdpSocket* udp, const std::string& message) { return udp->Send(message.c_str(), message.length()) == static_cast<ssize_t>(message.length()); } // Receives a string from a UdpSocket. Returns the string, or kReceiveStringError on failure. static std::string ReceiveString(UdpSocket* udp, size_t receive_size = 128) { std::vector<char> buffer(receive_size); ssize_t result = udp->Receive(buffer.data(), buffer.size(), kDefaultTimeoutMs); if (result >= 0) { return std::string(buffer.data(), result); } return kReceiveStringError; return true; } // Calls Receive() on the UdpSocket with the given timeout. Returns true if the call timed out. static bool ReceiveTimeout(UdpSocket* udp, int timeout_ms) { char buffer[1]; // Sends a string over a Socket. Returns true if the full string (without terminating char) // was sent. static bool SendString(Socket* sock, const std::string& message) { return sock->Send(message.c_str(), message.length()) == static_cast<ssize_t>(message.length()); } errno = 0; return udp->Receive(buffer, 1, timeout_ms) == -1 && (errno == EAGAIN || errno == EWOULDBLOCK); // Receives a string from a Socket. Returns true if the full string (without terminating char) // was received. static bool ReceiveString(Socket* sock, const std::string& message) { std::string received(message.length(), '\0'); ssize_t bytes = sock->ReceiveAll(&received[0], received.length(), kTestTimeoutMs); return static_cast<size_t>(bytes) == received.length() && received == message; } // Tests sending packets client -> server, then server -> client. TEST_F(SocketTest, SendAndReceive) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); TEST(SocketTest, TestSendAndReceive) { std::unique_ptr<Socket> server, client; for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); EXPECT_TRUE(SendString(client_.get(), "foo")); EXPECT_EQ("foo", ReceiveString(server_.get())); EXPECT_TRUE(SendString(client.get(), "foo")); EXPECT_TRUE(ReceiveString(server.get(), "foo")); EXPECT_TRUE(SendString(server_.get(), "bar baz")); EXPECT_EQ("bar baz", ReceiveString(client_.get())); EXPECT_TRUE(SendString(server.get(), "bar baz")); EXPECT_TRUE(ReceiveString(client.get(), "bar baz")); } } // Tests sending and receiving large packets. TEST_F(SocketTest, LargePackets) { std::string message(512, '\0'); TEST(SocketTest, TestLargePackets) { std::string message(1024, '\0'); std::unique_ptr<Socket> server, client; ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); // Run through the test a few times. for (int i = 0; i < 10; ++i) { Loading @@ -112,86 +102,25 @@ TEST_F(SocketTest, LargePackets) { message[j] = static_cast<char>(i + j); } EXPECT_TRUE(SendString(client_.get(), message)); EXPECT_EQ(message, ReceiveString(server_.get(), message.length())); } } // Tests IPv4 client/server. TEST_F(SocketTest, IPv4) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient("127.0.0.1")); EXPECT_TRUE(SendString(client_.get(), "foo")); EXPECT_EQ("foo", ReceiveString(server_.get())); EXPECT_TRUE(SendString(server_.get(), "bar")); EXPECT_EQ("bar", ReceiveString(client_.get())); } // Tests IPv6 client/server. TEST_F(SocketTest, IPv6) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient("::1")); EXPECT_TRUE(SendString(client_.get(), "foo")); EXPECT_EQ("foo", ReceiveString(server_.get())); EXPECT_TRUE(SendString(server_.get(), "bar")); EXPECT_EQ("bar", ReceiveString(client_.get())); } // Tests receive timeout. The timing verification logic must be very coarse to make sure different // systems running different loads can all pass these tests. TEST_F(SocketTest, ReceiveTimeout) { time_t start_time; ASSERT_TRUE(StartServer()); // Make sure a 20ms timeout completes in 1 second or less. start_time = time(nullptr); EXPECT_TRUE(ReceiveTimeout(server_.get(), 20)); EXPECT_LE(difftime(time(nullptr), start_time), 1.0); // Make sure a 1250ms timeout takes 1 second or more. start_time = time(nullptr); EXPECT_TRUE(ReceiveTimeout(server_.get(), 1250)); EXPECT_LE(1.0, difftime(time(nullptr), start_time)); EXPECT_TRUE(SendString(client.get(), message)); EXPECT_TRUE(ReceiveString(server.get(), message)); } // Tests receive overflow (the UDP packet is larger than the receive buffer). TEST_F(SocketTest, ReceiveOverflow) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); EXPECT_TRUE(SendString(client_.get(), "1234567890")); // This behaves differently on different systems; some give us a truncated UDP packet, others // will error out and not return anything at all. std::string rx_string = ReceiveString(server_.get(), 5); // If we didn't get an error then the packet should have been truncated. if (rx_string != kReceiveStringError) { EXPECT_EQ("12345", rx_string); } } // Tests multiple clients sending to the same server. TEST_F(SocketTest, MultipleClients) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); ASSERT_TRUE(StartClient2()); // Tests UDP receive overflow when the UDP packet is larger than the receive buffer. TEST(SocketTest, TestUdpReceiveOverflow) { std::unique_ptr<Socket> server, client; ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); EXPECT_TRUE(SendString(client_.get(), "client")); EXPECT_TRUE(SendString(client2_.get(), "client2")); EXPECT_TRUE(SendString(client.get(), "1234567890")); // Receive the packets and send a response for each (note that packets may be received // out-of-order). for (int i = 0; i < 2; ++i) { std::string received = ReceiveString(server_.get()); EXPECT_TRUE(SendString(server_.get(), received + " response")); // This behaves differently on different systems, either truncating the packet or returning -1. char buffer[5]; ssize_t bytes = server->Receive(buffer, 5, kTestTimeoutMs); if (bytes == 5) { EXPECT_EQ(0, memcmp(buffer, "12345", 5)); } else { EXPECT_EQ(-1, bytes); } EXPECT_EQ("client response", ReceiveString(client_.get())); EXPECT_EQ("client2 response", ReceiveString(client2_.get())); } Loading
fastboot/.clang-format +4 −1 Original line number Diff line number Diff line BasedOnStyle: Google AllowShortBlocksOnASingleLine: false AllowShortFunctionsOnASingleLine: false AllowShortFunctionsOnASingleLine: Inline ColumnLimit: 100 CommentPragmas: NOLINT:.* DerivePointerAlignment: false IndentWidth: 4 ContinuationIndentWidth: 8 ConstructorInitializerIndentWidth: 8 AccessModifierOffset: -2 PointerAlignment: Left TabWidth: 4 UseTab: Never Loading
fastboot/Android.mk +13 −9 Original line number Diff line number Diff line Loading @@ -24,7 +24,15 @@ LOCAL_C_INCLUDES := \ $(LOCAL_PATH)/../../extras/ext4_utils \ $(LOCAL_PATH)/../../extras/f2fs_utils \ LOCAL_SRC_FILES := protocol.cpp engine.cpp bootimg_utils.cpp fastboot.cpp util.cpp fs.cpp LOCAL_SRC_FILES := \ bootimg_utils.cpp \ engine.cpp \ fastboot.cpp \ fs.cpp\ protocol.cpp \ socket.cpp \ util.cpp \ LOCAL_MODULE := fastboot LOCAL_MODULE_TAGS := debug LOCAL_MODULE_HOST_OS := darwin linux windows Loading @@ -33,15 +41,15 @@ LOCAL_CFLAGS += -Wall -Wextra -Werror -Wunreachable-code LOCAL_CFLAGS += -DFASTBOOT_REVISION='"$(fastboot_version)"' LOCAL_SRC_FILES_linux := socket_unix.cpp usb_linux.cpp util_linux.cpp LOCAL_SRC_FILES_linux := usb_linux.cpp util_linux.cpp LOCAL_STATIC_LIBRARIES_linux := libselinux LOCAL_SRC_FILES_darwin := socket_unix.cpp usb_osx.cpp util_osx.cpp LOCAL_SRC_FILES_darwin := usb_osx.cpp util_osx.cpp LOCAL_STATIC_LIBRARIES_darwin := libselinux LOCAL_LDLIBS_darwin := -lpthread -framework CoreFoundation -framework IOKit -framework Carbon LOCAL_CFLAGS_darwin := -Wno-unused-parameter LOCAL_SRC_FILES_windows := socket_windows.cpp usb_windows.cpp util_windows.cpp LOCAL_SRC_FILES_windows := usb_windows.cpp util_windows.cpp LOCAL_STATIC_LIBRARIES_windows := AdbWinApi LOCAL_REQUIRED_MODULES_windows := AdbWinApi LOCAL_LDLIBS_windows := -lws2_32 Loading Loading @@ -98,18 +106,14 @@ include $(CLEAR_VARS) LOCAL_MODULE := fastboot_test LOCAL_MODULE_HOST_OS := darwin linux windows LOCAL_SRC_FILES := socket_test.cpp LOCAL_SRC_FILES := socket.cpp socket_test.cpp LOCAL_STATIC_LIBRARIES := libbase libcutils LOCAL_CFLAGS += -Wall -Wextra -Werror -Wunreachable-code LOCAL_SRC_FILES_linux := socket_unix.cpp LOCAL_SRC_FILES_darwin := socket_unix.cpp LOCAL_LDLIBS_darwin := -lpthread -framework CoreFoundation -framework IOKit -framework Carbon LOCAL_CFLAGS_darwin := -Wno-unused-parameter LOCAL_SRC_FILES_windows := socket_windows.cpp LOCAL_LDLIBS_windows := -lws2_32 include $(BUILD_HOST_NATIVE_TEST)
fastboot/socket_unix.cpp→fastboot/socket.cpp +212 −0 Original line number Diff line number Diff line Loading @@ -28,34 +28,74 @@ #include "socket.h" #include <errno.h> #include <netdb.h> #include <android-base/stringprintf.h> #include <cutils/sockets.h> class UnixUdpSocket : public UdpSocket { Socket::Socket(cutils_socket_t sock) : sock_(sock) {} Socket::~Socket() { Close(); } int Socket::Close() { int ret = 0; if (sock_ != INVALID_SOCKET) { ret = socket_close(sock_); sock_ = INVALID_SOCKET; } return ret; } bool Socket::SetReceiveTimeout(int timeout_ms) { if (timeout_ms != receive_timeout_ms_) { if (socket_set_receive_timeout(sock_, timeout_ms) == 0) { receive_timeout_ms_ = timeout_ms; return true; } return false; } return true; } ssize_t Socket::ReceiveAll(void* data, size_t length, int timeout_ms) { size_t total = 0; while (total < length) { ssize_t bytes = Receive(reinterpret_cast<char*>(data) + total, length - total, timeout_ms); if (bytes == -1) { if (total == 0) { return -1; } break; } total += bytes; } return total; } // Implements the Socket interface for UDP. class UdpSocket : public Socket { public: enum class Type { kClient, kServer }; UnixUdpSocket(int fd, Type type); ~UnixUdpSocket() override; UdpSocket(Type type, cutils_socket_t sock); ssize_t Send(const void* data, size_t length) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; int Close() override; private: int fd_; int receive_timeout_ms_ = 0; std::unique_ptr<sockaddr_storage> addr_; socklen_t addr_size_ = 0; DISALLOW_COPY_AND_ASSIGN(UnixUdpSocket); DISALLOW_COPY_AND_ASSIGN(UdpSocket); }; UnixUdpSocket::UnixUdpSocket(int fd, Type type) : fd_(fd) { // Only servers need to remember addresses; clients are connected to a server in NewUdpClient() UdpSocket::UdpSocket(Type type, cutils_socket_t sock) : Socket(sock) { // Only servers need to remember addresses; clients are connected to a server in NewClient() // so will send to that server without needing to specify the address again. if (type == Type::kServer) { addr_.reset(new sockaddr_storage); Loading @@ -64,26 +104,15 @@ UnixUdpSocket::UnixUdpSocket(int fd, Type type) : fd_(fd) { } } UnixUdpSocket::~UnixUdpSocket() { Close(); } ssize_t UnixUdpSocket::Send(const void* data, size_t length) { return TEMP_FAILURE_RETRY( sendto(fd_, data, length, 0, reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)); ssize_t UdpSocket::Send(const void* data, size_t length) { return TEMP_FAILURE_RETRY(sendto(sock_, reinterpret_cast<const char*>(data), length, 0, reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)); } ssize_t UnixUdpSocket::Receive(void* data, size_t length, int timeout_ms) { // Only set socket timeout if it's changed. if (receive_timeout_ms_ != timeout_ms) { timeval tv; tv.tv_sec = timeout_ms / 1000; tv.tv_usec = (timeout_ms % 1000) * 1000; if (setsockopt(fd_, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) { ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) { if (!SetReceiveTimeout(timeout_ms)) { return -1; } receive_timeout_ms_ = timeout_ms; } socklen_t* addr_size_ptr = nullptr; if (addr_ != nullptr) { Loading @@ -91,41 +120,93 @@ ssize_t UnixUdpSocket::Receive(void* data, size_t length, int timeout_ms) { addr_size_ = sizeof(*addr_); addr_size_ptr = &addr_size_; } return TEMP_FAILURE_RETRY(recvfrom(fd_, data, length, 0, return TEMP_FAILURE_RETRY(recvfrom(sock_, reinterpret_cast<char*>(data), length, 0, reinterpret_cast<sockaddr*>(addr_.get()), addr_size_ptr)); } int UnixUdpSocket::Close() { int result = 0; if (fd_ != -1) { result = close(fd_); fd_ = -1; // Implements the Socket interface for TCP. class TcpSocket : public Socket { public: TcpSocket(cutils_socket_t sock) : Socket(sock) {} ssize_t Send(const void* data, size_t length) override; ssize_t Receive(void* data, size_t length, int timeout_ms) override; std::unique_ptr<Socket> Accept() override; private: DISALLOW_COPY_AND_ASSIGN(TcpSocket); }; ssize_t TcpSocket::Send(const void* data, size_t length) { size_t total = 0; while (total < length) { ssize_t bytes = TEMP_FAILURE_RETRY( send(sock_, reinterpret_cast<const char*>(data) + total, length - total, 0)); if (bytes == -1) { if (total == 0) { return -1; } break; } return result; total += bytes; } std::unique_ptr<UdpSocket> UdpSocket::NewUdpClient(const std::string& host, int port, std::string* error) { int getaddrinfo_error = 0; int fd = socket_network_client_timeout(host.c_str(), port, SOCK_DGRAM, 0, &getaddrinfo_error); if (fd == -1) { if (error) { *error = android::base::StringPrintf( "Failed to connect to %s:%d: %s", host.c_str(), port, getaddrinfo_error ? gai_strerror(getaddrinfo_error) : strerror(errno)); return total; } ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) { if (!SetReceiveTimeout(timeout_ms)) { return -1; } return TEMP_FAILURE_RETRY(recv(sock_, reinterpret_cast<char*>(data), length, 0)); } std::unique_ptr<Socket> TcpSocket::Accept() { cutils_socket_t handler = accept(sock_, nullptr, nullptr); if (handler == INVALID_SOCKET) { return nullptr; } return std::unique_ptr<TcpSocket>(new TcpSocket(handler)); } return std::unique_ptr<UdpSocket>(new UnixUdpSocket(fd, UnixUdpSocket::Type::kClient)); std::unique_ptr<Socket> Socket::NewClient(Protocol protocol, const std::string& host, int port, std::string* error) { if (protocol == Protocol::kUdp) { cutils_socket_t sock = socket_network_client(host.c_str(), port, SOCK_DGRAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<UdpSocket>(new UdpSocket(UdpSocket::Type::kClient, sock)); } } else { cutils_socket_t sock = socket_network_client(host.c_str(), port, SOCK_STREAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<TcpSocket>(new TcpSocket(sock)); } } std::unique_ptr<UdpSocket> UdpSocket::NewUdpServer(int port) { int fd = socket_inaddr_any_server(port, SOCK_DGRAM); if (fd == -1) { // This is just used in testing, no need for an error message. if (error) { *error = android::base::StringPrintf("Failed to connect to %s:%d", host.c_str(), port); } return nullptr; } return std::unique_ptr<UdpSocket>(new UnixUdpSocket(fd, UnixUdpSocket::Type::kServer)); // This functionality is currently only used by tests so we don't need any error messages. std::unique_ptr<Socket> Socket::NewServer(Protocol protocol, int port) { if (protocol == Protocol::kUdp) { cutils_socket_t sock = socket_inaddr_any_server(port, SOCK_DGRAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<UdpSocket>(new UdpSocket(UdpSocket::Type::kServer, sock)); } } else { cutils_socket_t sock = socket_inaddr_any_server(port, SOCK_STREAM); if (sock != INVALID_SOCKET) { return std::unique_ptr<TcpSocket>(new TcpSocket(sock)); } } return nullptr; }
fastboot/socket.h +35 −15 Original line number Diff line number Diff line Loading @@ -26,36 +26,41 @@ * SUCH DAMAGE. */ // This file provides a class interface for cross-platform UDP functionality. The main fastboot // This file provides a class interface for cross-platform socket functionality. The main fastboot // engine should not be using this interface directly, but instead should use a higher-level // interface that enforces the fastboot UDP protocol. // interface that enforces the fastboot protocol. #ifndef SOCKET_H_ #define SOCKET_H_ #include "android-base/macros.h" #include <memory> #include <string> // UdpSocket interface to be implemented for each platform. class UdpSocket { #include <android-base/macros.h> #include <cutils/sockets.h> // Socket interface to be implemented for each platform. class Socket { public: enum class Protocol { kTcp, kUdp }; // Creates a new client connection. Clients are connected to a specific hostname/port and can // only send to that destination. // On failure, |error| is filled (if non-null) and nullptr is returned. static std::unique_ptr<UdpSocket> NewUdpClient(const std::string& hostname, int port, std::string* error); static std::unique_ptr<Socket> NewClient(Protocol protocol, const std::string& hostname, int port, std::string* error); // Creates a new server bound to local |port|. This is only meant for testing, during normal // fastboot operation the device acts as the server. // The server saves sender addresses in Receive(), and uses the most recent address during // A UDP server saves sender addresses in Receive(), and uses the most recent address during // calls to Send(). static std::unique_ptr<UdpSocket> NewUdpServer(int port); static std::unique_ptr<Socket> NewServer(Protocol protocol, int port); virtual ~UdpSocket() = default; // Destructor closes the socket if it's open. virtual ~Socket(); // Sends |length| bytes of |data|. Returns the number of bytes actually sent or -1 on error. // Sends |length| bytes of |data|. For TCP sockets this will continue trying to send until all // bytes are transmitted. Returns the number of bytes actually sent or -1 on error. virtual ssize_t Send(const void* data, size_t length) = 0; // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will Loading @@ -63,14 +68,29 @@ class UdpSocket { // errno will be set to EAGAIN or EWOULDBLOCK. virtual ssize_t Receive(void* data, size_t length, int timeout_ms) = 0; // Calls Receive() until exactly |length| bytes have been received or an error occurs. virtual ssize_t ReceiveAll(void* data, size_t length, int timeout_ms); // Closes the socket. Returns 0 on success, -1 on error. virtual int Close() = 0; virtual int Close(); // Accepts an incoming TCP connection. No effect for UDP sockets. Returns a new Socket // connected to the client on success, nullptr on failure. virtual std::unique_ptr<Socket> Accept() { return nullptr; } protected: // Protected constructor to force factory function use. UdpSocket() = default; Socket(cutils_socket_t sock); // Update the socket receive timeout if necessary. bool SetReceiveTimeout(int timeout_ms); cutils_socket_t sock_ = INVALID_SOCKET; private: int receive_timeout_ms_ = 0; DISALLOW_COPY_AND_ASSIGN(UdpSocket); DISALLOW_COPY_AND_ASSIGN(Socket); }; #endif // SOCKET_H_
fastboot/socket_test.cpp +71 −142 Original line number Diff line number Diff line Loading @@ -14,96 +14,86 @@ * limitations under the License. */ // Tests UDP functionality using loopback connections. Requires that kDefaultPort is available // Tests UDP functionality using loopback connections. Requires that kTestPort is available // for loopback communication on the host. These tests also assume that no UDP packets are lost, // which should be the case for loopback communication, but is not guaranteed. #include "socket.h" #include <errno.h> #include <time.h> #include <memory> #include <string> #include <vector> #include <gtest/gtest.h> enum { // This port must be available for loopback communication. kDefaultPort = 54321, kTestPort = 54321, // Don't wait forever in a unit test. kDefaultTimeoutMs = 3000, kTestTimeoutMs = 3000, }; static const char kReceiveStringError[] = "Error receiving string"; // Test fixture to provide some helper functions. Makes each test a little simpler since we can // just check a bool for socket creation and don't have to pass hostname or port information. class SocketTest : public ::testing::Test { protected: bool StartServer(int port = kDefaultPort) { server_ = UdpSocket::NewUdpServer(port); return server_ != nullptr; // Creates connected sockets |server| and |client|. Returns true on success. bool MakeConnectedSockets(Socket::Protocol protocol, std::unique_ptr<Socket>* server, std::unique_ptr<Socket>* client, const std::string hostname = "localhost", int port = kTestPort) { *server = Socket::NewServer(protocol, port); if (*server == nullptr) { ADD_FAILURE() << "Failed to create server."; return false; } bool StartClient(const std::string hostname = "localhost", int port = kDefaultPort) { client_ = UdpSocket::NewUdpClient(hostname, port, nullptr); return client_ != nullptr; *client = Socket::NewClient(protocol, hostname, port, nullptr); if (*client == nullptr) { ADD_FAILURE() << "Failed to create client."; return false; } bool StartClient2(const std::string hostname = "localhost", int port = kDefaultPort) { client2_ = UdpSocket::NewUdpClient(hostname, port, nullptr); return client2_ != nullptr; // TCP passes the client off to a new socket. if (protocol == Socket::Protocol::kTcp) { *server = (*server)->Accept(); if (*server == nullptr) { ADD_FAILURE() << "Failed to accept client connection."; return false; } std::unique_ptr<UdpSocket> server_, client_, client2_; }; // Sends a string over a UdpSocket. Returns true if the full string (without terminating char) // was sent. static bool SendString(UdpSocket* udp, const std::string& message) { return udp->Send(message.c_str(), message.length()) == static_cast<ssize_t>(message.length()); } // Receives a string from a UdpSocket. Returns the string, or kReceiveStringError on failure. static std::string ReceiveString(UdpSocket* udp, size_t receive_size = 128) { std::vector<char> buffer(receive_size); ssize_t result = udp->Receive(buffer.data(), buffer.size(), kDefaultTimeoutMs); if (result >= 0) { return std::string(buffer.data(), result); } return kReceiveStringError; return true; } // Calls Receive() on the UdpSocket with the given timeout. Returns true if the call timed out. static bool ReceiveTimeout(UdpSocket* udp, int timeout_ms) { char buffer[1]; // Sends a string over a Socket. Returns true if the full string (without terminating char) // was sent. static bool SendString(Socket* sock, const std::string& message) { return sock->Send(message.c_str(), message.length()) == static_cast<ssize_t>(message.length()); } errno = 0; return udp->Receive(buffer, 1, timeout_ms) == -1 && (errno == EAGAIN || errno == EWOULDBLOCK); // Receives a string from a Socket. Returns true if the full string (without terminating char) // was received. static bool ReceiveString(Socket* sock, const std::string& message) { std::string received(message.length(), '\0'); ssize_t bytes = sock->ReceiveAll(&received[0], received.length(), kTestTimeoutMs); return static_cast<size_t>(bytes) == received.length() && received == message; } // Tests sending packets client -> server, then server -> client. TEST_F(SocketTest, SendAndReceive) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); TEST(SocketTest, TestSendAndReceive) { std::unique_ptr<Socket> server, client; for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); EXPECT_TRUE(SendString(client_.get(), "foo")); EXPECT_EQ("foo", ReceiveString(server_.get())); EXPECT_TRUE(SendString(client.get(), "foo")); EXPECT_TRUE(ReceiveString(server.get(), "foo")); EXPECT_TRUE(SendString(server_.get(), "bar baz")); EXPECT_EQ("bar baz", ReceiveString(client_.get())); EXPECT_TRUE(SendString(server.get(), "bar baz")); EXPECT_TRUE(ReceiveString(client.get(), "bar baz")); } } // Tests sending and receiving large packets. TEST_F(SocketTest, LargePackets) { std::string message(512, '\0'); TEST(SocketTest, TestLargePackets) { std::string message(1024, '\0'); std::unique_ptr<Socket> server, client; ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); for (Socket::Protocol protocol : {Socket::Protocol::kUdp, Socket::Protocol::kTcp}) { ASSERT_TRUE(MakeConnectedSockets(protocol, &server, &client)); // Run through the test a few times. for (int i = 0; i < 10; ++i) { Loading @@ -112,86 +102,25 @@ TEST_F(SocketTest, LargePackets) { message[j] = static_cast<char>(i + j); } EXPECT_TRUE(SendString(client_.get(), message)); EXPECT_EQ(message, ReceiveString(server_.get(), message.length())); } } // Tests IPv4 client/server. TEST_F(SocketTest, IPv4) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient("127.0.0.1")); EXPECT_TRUE(SendString(client_.get(), "foo")); EXPECT_EQ("foo", ReceiveString(server_.get())); EXPECT_TRUE(SendString(server_.get(), "bar")); EXPECT_EQ("bar", ReceiveString(client_.get())); } // Tests IPv6 client/server. TEST_F(SocketTest, IPv6) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient("::1")); EXPECT_TRUE(SendString(client_.get(), "foo")); EXPECT_EQ("foo", ReceiveString(server_.get())); EXPECT_TRUE(SendString(server_.get(), "bar")); EXPECT_EQ("bar", ReceiveString(client_.get())); } // Tests receive timeout. The timing verification logic must be very coarse to make sure different // systems running different loads can all pass these tests. TEST_F(SocketTest, ReceiveTimeout) { time_t start_time; ASSERT_TRUE(StartServer()); // Make sure a 20ms timeout completes in 1 second or less. start_time = time(nullptr); EXPECT_TRUE(ReceiveTimeout(server_.get(), 20)); EXPECT_LE(difftime(time(nullptr), start_time), 1.0); // Make sure a 1250ms timeout takes 1 second or more. start_time = time(nullptr); EXPECT_TRUE(ReceiveTimeout(server_.get(), 1250)); EXPECT_LE(1.0, difftime(time(nullptr), start_time)); EXPECT_TRUE(SendString(client.get(), message)); EXPECT_TRUE(ReceiveString(server.get(), message)); } // Tests receive overflow (the UDP packet is larger than the receive buffer). TEST_F(SocketTest, ReceiveOverflow) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); EXPECT_TRUE(SendString(client_.get(), "1234567890")); // This behaves differently on different systems; some give us a truncated UDP packet, others // will error out and not return anything at all. std::string rx_string = ReceiveString(server_.get(), 5); // If we didn't get an error then the packet should have been truncated. if (rx_string != kReceiveStringError) { EXPECT_EQ("12345", rx_string); } } // Tests multiple clients sending to the same server. TEST_F(SocketTest, MultipleClients) { ASSERT_TRUE(StartServer()); ASSERT_TRUE(StartClient()); ASSERT_TRUE(StartClient2()); // Tests UDP receive overflow when the UDP packet is larger than the receive buffer. TEST(SocketTest, TestUdpReceiveOverflow) { std::unique_ptr<Socket> server, client; ASSERT_TRUE(MakeConnectedSockets(Socket::Protocol::kUdp, &server, &client)); EXPECT_TRUE(SendString(client_.get(), "client")); EXPECT_TRUE(SendString(client2_.get(), "client2")); EXPECT_TRUE(SendString(client.get(), "1234567890")); // Receive the packets and send a response for each (note that packets may be received // out-of-order). for (int i = 0; i < 2; ++i) { std::string received = ReceiveString(server_.get()); EXPECT_TRUE(SendString(server_.get(), received + " response")); // This behaves differently on different systems, either truncating the packet or returning -1. char buffer[5]; ssize_t bytes = server->Receive(buffer, 5, kTestTimeoutMs); if (bytes == 5) { EXPECT_EQ(0, memcmp(buffer, "12345", 5)); } else { EXPECT_EQ(-1, bytes); } EXPECT_EQ("client response", ReceiveString(client_.get())); EXPECT_EQ("client2 response", ReceiveString(client2_.get())); }