Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 7d1279b3 authored by Chia-chi Yeh's avatar Chia-chi Yeh Committed by Android (Google) Code Review
Browse files

Merge "VPN: refactor few JNI methods for the usage of legacy VPN."

parents e58e9b8c f4e3bf89
Loading
Loading
Loading
Loading
+23 −16
Original line number Diff line number Diff line
@@ -90,7 +90,7 @@ public class Vpn extends INetworkManagementEventObserver.Stub {

        // Reset the interface and hide the notification.
        if (mInterfaceName != null) {
            nativeReset(mInterfaceName);
            jniResetInterface(mInterfaceName);
            mCallback.restore();
            hideNotification();
            mInterfaceName = null;
@@ -119,7 +119,7 @@ public class Vpn extends INetworkManagementEventObserver.Stub {
    public void protect(ParcelFileDescriptor socket, String name) {
        try {
            mContext.enforceCallingPermission(VPN, "protect");
            nativeProtect(socket.getFd(), name);
            jniProtectSocket(socket.getFd(), name);
        } finally {
            try {
                socket.close();
@@ -152,17 +152,22 @@ public class Vpn extends INetworkManagementEventObserver.Stub {
        }

        // Create and configure the interface.
        ParcelFileDescriptor descriptor = ParcelFileDescriptor.adoptFd(
                nativeEstablish(config.mtu, config.addresses, config.routes));
        ParcelFileDescriptor descriptor =
                ParcelFileDescriptor.adoptFd(jniCreateInterface(config.mtu));

        // Replace the interface and abort if it fails.
        // Abort if any of the following steps fails.
        try {
            String interfaceName = nativeGetName(descriptor.getFd());

            if (mInterfaceName != null && !mInterfaceName.equals(interfaceName)) {
                nativeReset(mInterfaceName);
            String name = jniGetInterfaceName(descriptor.getFd());
            if (jniSetAddresses(name, config.addresses) < 1) {
                throw new IllegalArgumentException("At least one address must be specified");
            }
            if (config.routes != null) {
                jniSetRoutes(name, config.routes);
            }
            if (mInterfaceName != null && !mInterfaceName.equals(name)) {
                jniResetInterface(mInterfaceName);
            }
            mInterfaceName = interfaceName;
            mInterfaceName = name;
        } catch (RuntimeException e) {
            try {
                descriptor.close();
@@ -195,7 +200,7 @@ public class Vpn extends INetworkManagementEventObserver.Stub {

    // INetworkManagementEventObserver.Stub
    public synchronized void interfaceRemoved(String name) {
        if (name.equals(mInterfaceName) && nativeCheck(name) == 0) {
        if (name.equals(mInterfaceName) && jniCheckInterface(name) == 0) {
            hideNotification();
            mInterfaceName = null;
            mCallback.restore();
@@ -253,11 +258,13 @@ public class Vpn extends INetworkManagementEventObserver.Stub {
        }
    }

    private native int nativeEstablish(int mtu, String addresses, String routes);
    private native String nativeGetName(int fd);
    private native void nativeReset(String name);
    private native int nativeCheck(String name);
    private native void nativeProtect(int fd, String name);
    private native int jniCreateInterface(int mtu);
    private native String jniGetInterfaceName(int fd);
    private native int jniSetAddresses(String name, String addresses);
    private native int jniSetRoutes(String name, String routes);
    private native void jniResetInterface(String name);
    private native int jniCheckInterface(String name);
    private native void jniProtectSocket(int fd, String name);

    /**
     * Handle legacy VPN requests. This method stops the services and restart
+147 −114
Original line number Diff line number Diff line
@@ -42,6 +42,9 @@
namespace android
{

static int inet4 = -1;
static int inet6 = -1;

static inline in_addr_t *as_in_addr(sockaddr *sa) {
    return &((sockaddr_in *)sa)->sin_addr.s_addr;
}
@@ -51,11 +54,9 @@ static inline in_addr_t *as_in_addr(sockaddr *sa) {
#define SYSTEM_ERROR -1
#define BAD_ARGUMENT -2

static int create_interface(int mtu, char *name, int *index)
static int create_interface(int mtu)
{
    int tun = open("/dev/tun", O_RDWR);
    int inet4 = socket(AF_INET, SOCK_DGRAM, 0);
    int flags;
    int tun = open("/dev/tun", O_RDWR | O_NONBLOCK);

    ifreq ifr4;
    memset(&ifr4, 0, sizeof(ifr4));
@@ -81,38 +82,45 @@ static int create_interface(int mtu, char *name, int *index)
        goto error;
    }

    // Get interface index.
    if (ioctl(inet4, SIOGIFINDEX, &ifr4)) {
        LOGE("Cannot get index of %s: %s", ifr4.ifr_name, strerror(errno));
        goto error;
    }

    // Make it non-blocking.
    flags = fcntl(tun, F_GETFL, 0);
    if (flags == -1 || fcntl(tun, F_SETFL, flags | O_NONBLOCK)) {
        LOGE("Cannot set non-blocking on %s: %s", ifr4.ifr_name, strerror(errno));
        goto error;
    }

    strcpy(name, ifr4.ifr_name);
    *index = ifr4.ifr_ifindex;
    close(inet4);
    return tun;

error:
    close(tun);
    close(inet4);
    return SYSTEM_ERROR;
}

static int set_addresses(const char *name, int index, const char *addresses)
static int get_interface_name(char *name, int tun)
{
    int inet4 = socket(AF_INET, SOCK_DGRAM, 0);
    int inet6 = socket(AF_INET6, SOCK_DGRAM, 0);
    ifreq ifr4;
    if (ioctl(tun, TUNGETIFF, &ifr4)) {
        LOGE("Cannot get interface name: %s", strerror(errno));
        return SYSTEM_ERROR;
    }
    strncpy(name, ifr4.ifr_name, IFNAMSIZ);
    return 0;
}

static int get_interface_index(const char *name)
{
    ifreq ifr4;
    strncpy(ifr4.ifr_name, name, IFNAMSIZ);
    if (ioctl(inet4, SIOGIFINDEX, &ifr4)) {
        LOGE("Cannot get index of %s: %s", name, strerror(errno));
        return SYSTEM_ERROR;
    }
    return ifr4.ifr_ifindex;
}

static int set_addresses(const char *name, const char *addresses)
{
    int index = get_interface_index(name);
    if (index < 0) {
        return index;
    }

    ifreq ifr4;
    memset(&ifr4, 0, sizeof(ifr4));
    strcpy(ifr4.ifr_name, name);
    strncpy(ifr4.ifr_name, name, IFNAMSIZ);
    ifr4.ifr_addr.sa_family = AF_INET;

    in6_ifreq ifr6;
@@ -121,7 +129,6 @@ static int set_addresses(const char *name, int index, const char *addresses)

    char address[65];
    int prefix;

    int chars;
    int count = 0;

@@ -164,7 +171,7 @@ static int set_addresses(const char *name, int index, const char *addresses)
                break;
            }
        }
        LOGV("Address added on %s: %s/%d", name, address, prefix);
        LOGD("Address added on %s: %s/%d", name, address, prefix);
        ++count;
    }

@@ -177,15 +184,15 @@ static int set_addresses(const char *name, int index, const char *addresses)
        count = BAD_ARGUMENT;
    }

    close(inet4);
    close(inet6);
    return count;
}

static int set_routes(const char *name, int index, const char *routes)
static int set_routes(const char *name, const char *routes)
{
    int inet4 = socket(AF_INET, SOCK_DGRAM, 0);
    int inet6 = socket(AF_INET6, SOCK_DGRAM, 0);
    int index = get_interface_index(name);
    if (index < 0) {
        return index;
    }

    rtentry rt4;
    memset(&rt4, 0, sizeof(rt4));
@@ -201,7 +208,6 @@ static int set_routes(const char *name, int index, const char *routes)

    char address[65];
    int prefix;

    int chars;
    int count = 0;

@@ -211,32 +217,50 @@ static int set_routes(const char *name, int index, const char *routes)
        if (strchr(address, ':')) {
            // Add an IPv6 route.
            if (inet_pton(AF_INET6, address, &rt6.rtmsg_dst) != 1 ||
                    prefix < 1 || prefix > 128) {
                    prefix < 0 || prefix > 128) {
                count = BAD_ARGUMENT;
                break;
            }

            rt6.rtmsg_dst_len = prefix;
            rt6.rtmsg_dst_len = prefix ? prefix : 1;
            if (ioctl(inet6, SIOCADDRT, &rt6) && errno != EEXIST) {
                count = (errno == EINVAL) ? BAD_ARGUMENT : SYSTEM_ERROR;
                break;
            }

            if (!prefix) {
                // Split the route instead of replacing the default route.
                rt6.rtmsg_dst.s6_addr[0] ^= 0x80;
                if (ioctl(inet6, SIOCADDRT, &rt6) && errno != EEXIST) {
                    count = SYSTEM_ERROR;
                    break;
                }
            }
        } else {
            // Add an IPv4 route.
            if (inet_pton(AF_INET, address, as_in_addr(&rt4.rt_dst)) != 1 ||
                    prefix < 1 || prefix > 32) {
                    prefix < 0 || prefix > 32) {
                count = BAD_ARGUMENT;
                break;
            }

            in_addr_t mask = prefix ? (~0 << (32 - prefix)) : 0;
            in_addr_t mask = prefix ? (~0 << (32 - prefix)) : 1;
            *as_in_addr(&rt4.rt_genmask) = htonl(mask);
            if (ioctl(inet4, SIOCADDRT, &rt4) && errno != EEXIST) {
                count = (errno == EINVAL) ? BAD_ARGUMENT : SYSTEM_ERROR;
                break;
            }

            if (!prefix) {
                // Split the route instead of replacing the default route.
                *as_in_addr(&rt4.rt_dst) ^= htonl(0x80000000);
                if (ioctl(inet4, SIOCADDRT, &rt4) && errno != EEXIST) {
                    count = SYSTEM_ERROR;
                    break;
                }
            }
        LOGV("Route added on %s: %s/%d", name, address, prefix);
        }
        LOGD("Route added on %s: %s/%d", name, address, prefix);
        ++count;
    }

@@ -250,43 +274,24 @@ static int set_routes(const char *name, int index, const char *routes)
        count = BAD_ARGUMENT;
    }

    close(inet4);
    close(inet6);
    return count;
}

static int get_interface_name(char *name, int tun)
{
    ifreq ifr4;
    if (ioctl(tun, TUNGETIFF, &ifr4)) {
        LOGE("Cannot get interface name: %s", strerror(errno));
        return SYSTEM_ERROR;
    }
    strcpy(name, ifr4.ifr_name);
    return 0;
}

static int reset_interface(const char *name)
{
    int inet4 = socket(AF_INET, SOCK_DGRAM, 0);

    ifreq ifr4;
    ifr4.ifr_flags = 0;
    strncpy(ifr4.ifr_name, name, IFNAMSIZ);
    ifr4.ifr_flags = 0;

    if (ioctl(inet4, SIOCSIFFLAGS, &ifr4) && errno != ENODEV) {
        LOGE("Cannot reset %s: %s", name, strerror(errno));
        close(inet4);
        return SYSTEM_ERROR;
    }
    close(inet4);
    return 0;
}

static int check_interface(const char *name)
{
    int inet4 = socket(AF_INET, SOCK_DGRAM, 0);

    ifreq ifr4;
    strncpy(ifr4.ifr_name, name, IFNAMSIZ);
    ifr4.ifr_flags = 0;
@@ -294,7 +299,6 @@ static int check_interface(const char *name)
    if (ioctl(inet4, SIOCGIFFLAGS, &ifr4) && errno != ENODEV) {
        LOGE("Cannot check %s: %s", name, strerror(errno));
    }
    close(inet4);
    return ifr4.ifr_flags;
}

@@ -318,86 +322,108 @@ static void throwException(JNIEnv *env, int error, const char *message)
    }
}

static jint establish(JNIEnv *env, jobject thiz,
        jint mtu, jstring jAddresses, jstring jRoutes)
static jint createInterface(JNIEnv *env, jobject thiz, jint mtu)
{
    char name[IFNAMSIZ];
    int index;
    int tun = create_interface(mtu, name, &index);
    int tun = create_interface(mtu);
    if (tun < 0) {
        throwException(env, tun, "Cannot create interface");
        return -1;
    }
    LOGD("%s is created", name);
    return tun;
}

static jstring getInterfaceName(JNIEnv *env, jobject thiz, jint tun)
{
    char name[IFNAMSIZ];
    if (get_interface_name(name, tun) < 0) {
        throwException(env, SYSTEM_ERROR, "Cannot get interface name");
        return NULL;
    }
    return env->NewStringUTF(name);
}

    const char *addresses;
    const char *routes;
    int count;
static jint setAddresses(JNIEnv *env, jobject thiz, jstring jName,
        jstring jAddresses)
{
    const char *name = NULL;
    const char *addresses = NULL;
    int count = -1;

    // Addresses are required.
    name = jName ? env->GetStringUTFChars(jName, NULL) : NULL;
    if (!name) {
        jniThrowNullPointerException(env, "name");
        goto error;
    }
    addresses = jAddresses ? env->GetStringUTFChars(jAddresses, NULL) : NULL;
    if (!addresses) {
        jniThrowNullPointerException(env, "address");
        jniThrowNullPointerException(env, "addresses");
        goto error;
    }
    count = set_addresses(name, index, addresses);
    env->ReleaseStringUTFChars(jAddresses, addresses);
    if (count <= 0) {
    count = set_addresses(name, addresses);
    if (count < 0) {
        throwException(env, count, "Cannot set address");
        goto error;
        count = -1;
    }

error:
    if (name) {
        env->ReleaseStringUTFChars(jName, name);
    }
    if (addresses) {
        env->ReleaseStringUTFChars(jAddresses, addresses);
    }
    return count;
}
    LOGD("Configured %d address(es) on %s", count, name);

    // Routes are optional.
static jint setRoutes(JNIEnv *env, jobject thiz, jstring jName,
        jstring jRoutes)
{
    const char *name = NULL;
    const char *routes = NULL;
    int count = -1;

    name = jName ? env->GetStringUTFChars(jName, NULL) : NULL;
    if (!name) {
        jniThrowNullPointerException(env, "name");
        goto error;
    }
    routes = jRoutes ? env->GetStringUTFChars(jRoutes, NULL) : NULL;
    if (routes) {
        count = set_routes(name, index, routes);
        env->ReleaseStringUTFChars(jRoutes, routes);
        if (count < 0) {
            throwException(env, count, "Cannot set route");
    if (!routes) {
        jniThrowNullPointerException(env, "routes");
        goto error;
    }
        LOGD("Configured %d route(s) on %s", count, name);
    count = set_routes(name, routes);
    if (count < 0) {
        throwException(env, count, "Cannot set address");
        count = -1;
    }

    return tun;

error:
    close(tun);
    LOGD("%s is destroyed", name);
    return -1;
    if (name) {
        env->ReleaseStringUTFChars(jName, name);
    }

static jstring getName(JNIEnv *env, jobject thiz, jint fd)
{
    char name[IFNAMSIZ];
    if (get_interface_name(name, fd) < 0) {
        throwException(env, SYSTEM_ERROR, "Cannot get interface name");
        return NULL;
    if (routes) {
        env->ReleaseStringUTFChars(jRoutes, routes);
    }
    return env->NewStringUTF(name);
    return count;
}

static void reset(JNIEnv *env, jobject thiz, jstring jName)
static void resetInterface(JNIEnv *env, jobject thiz, jstring jName)
{
    const char *name = jName ?
            env->GetStringUTFChars(jName, NULL) : NULL;
    const char *name = jName ? env->GetStringUTFChars(jName, NULL) : NULL;
    if (!name) {
        jniThrowNullPointerException(env, "name");
        return;
    }
    if (reset_interface(name) < 0) {
        throwException(env, SYSTEM_ERROR, "Cannot reset interface");
    } else {
        LOGD("%s is deactivated", name);
    }
    env->ReleaseStringUTFChars(jName, name);
}

static jint check(JNIEnv *env, jobject thiz, jstring jName)
static jint checkInterface(JNIEnv *env, jobject thiz, jstring jName)
{
    const char *name = jName ?
            env->GetStringUTFChars(jName, NULL) : NULL;
    const char *name = jName ? env->GetStringUTFChars(jName, NULL) : NULL;
    if (!name) {
        jniThrowNullPointerException(env, "name");
        return 0;
@@ -407,10 +433,9 @@ static jint check(JNIEnv *env, jobject thiz, jstring jName)
    return flags;
}

static void protect(JNIEnv *env, jobject thiz, jint fd, jstring jName)
static void protectSocket(JNIEnv *env, jobject thiz, jint fd, jstring jName)
{
    const char *name = jName ?
            env->GetStringUTFChars(jName, NULL) : NULL;
    const char *name = jName ? env->GetStringUTFChars(jName, NULL) : NULL;
    if (!name) {
        jniThrowNullPointerException(env, "name");
        return;
@@ -424,15 +449,23 @@ static void protect(JNIEnv *env, jobject thiz, jint fd, jstring jName)
//------------------------------------------------------------------------------

static JNINativeMethod gMethods[] = {
    {"nativeEstablish", "(ILjava/lang/String;Ljava/lang/String;)I", (void *)establish},
    {"nativeGetName", "(I)Ljava/lang/String;", (void *)getName},
    {"nativeReset", "(Ljava/lang/String;)V", (void *)reset},
    {"nativeCheck", "(Ljava/lang/String;)I", (void *)check},
    {"nativeProtect", "(ILjava/lang/String;)V", (void *)protect},
    {"jniCreateInterface", "(I)I", (void *)createInterface},
    {"jniGetInterfaceName", "(I)Ljava/lang/String;", (void *)getInterfaceName},
    {"jniSetAddresses", "(Ljava/lang/String;Ljava/lang/String;)I", (void *)setAddresses},
    {"jniSetRoutes", "(Ljava/lang/String;Ljava/lang/String;)I", (void *)setRoutes},
    {"jniResetInterface", "(Ljava/lang/String;)V", (void *)resetInterface},
    {"jniCheckInterface", "(Ljava/lang/String;)I", (void *)checkInterface},
    {"jniProtectSocket", "(ILjava/lang/String;)V", (void *)protectSocket},
};

int register_android_server_connectivity_Vpn(JNIEnv *env)
{
    if (inet4 == -1) {
        inet4 = socket(AF_INET, SOCK_DGRAM, 0);
    }
    if (inet6 == -1) {
        inet6 = socket(AF_INET6, SOCK_DGRAM, 0);
    }
    return jniRegisterNativeMethods(env, "com/android/server/connectivity/Vpn",
            gMethods, NELEM(gMethods));
}