Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 2b49a7df authored by Erik Kline's avatar Erik Kline Committed by Android (Google) Code Review
Browse files

Merge "Incorrect time used in some NTP server responses"

parents cb45c10b 32d52f34
Loading
Loading
Loading
Loading
+78 −17
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ import android.util.Log;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.util.Arrays;

/**
 * {@hide}
@@ -38,6 +39,7 @@ import java.net.InetAddress;
public class SntpClient
{
    private static final String TAG = "SntpClient";
    private static final boolean DBG = true;

    private static final int REFERENCE_TIME_OFFSET = 16;
    private static final int ORIGINATE_TIME_OFFSET = 24;
@@ -47,8 +49,14 @@ public class SntpClient

    private static final int NTP_PORT = 123;
    private static final int NTP_MODE_CLIENT = 3;
    private static final int NTP_MODE_SERVER = 4;
    private static final int NTP_MODE_BROADCAST = 5;
    private static final int NTP_VERSION = 3;

    private static final int NTP_LEAP_NOSYNC = 3;
    private static final int NTP_STRATUM_DEATH = 0;
    private static final int NTP_STRATUM_MAX = 15;

    // Number of seconds between Jan 1, 1900 and Jan 1, 1970
    // 70 years plus 17 leap days
    private static final long OFFSET_1900_TO_1970 = ((365L * 70L) + 17L) * 24L * 60L * 60L;
@@ -62,6 +70,12 @@ public class SntpClient
    // round trip time in milliseconds
    private long mRoundTripTime;

    private static class InvalidServerReplyException extends Exception {
        public InvalidServerReplyException(String message) {
            super(message);
        }
    }

    /**
     * Sends an SNTP request to the given host and processes the response.
     *
@@ -70,13 +84,23 @@ public class SntpClient
     * @return true if the transaction was successful.
     */
    public boolean requestTime(String host, int timeout) {
        InetAddress address = null;
        try {
            address = InetAddress.getByName(host);
        } catch (Exception e) {
            if (DBG) Log.d(TAG, "request time failed: " + e);
            return false;
        }
        return requestTime(address, NTP_PORT, timeout);
    }

    public boolean requestTime(InetAddress address, int port, int timeout) {
        DatagramSocket socket = null;
        try {
            socket = new DatagramSocket();
            socket.setSoTimeout(timeout);
            InetAddress address = InetAddress.getByName(host);
            byte[] buffer = new byte[NTP_PACKET_SIZE];
            DatagramPacket request = new DatagramPacket(buffer, buffer.length, address, NTP_PORT);
            DatagramPacket request = new DatagramPacket(buffer, buffer.length, address, port);

            // set mode = 3 (client) and version = 3
            // mode is in low 3 bits of first byte
@@ -84,8 +108,8 @@ public class SntpClient
            buffer[0] = NTP_MODE_CLIENT | (NTP_VERSION << 3);

            // get current time and write it to the request packet
            long requestTime = System.currentTimeMillis();
            long requestTicks = SystemClock.elapsedRealtime();
            final long requestTime = System.currentTimeMillis();
            final long requestTicks = SystemClock.elapsedRealtime();
            writeTimeStamp(buffer, TRANSMIT_TIME_OFFSET, requestTime);

            socket.send(request);
@@ -93,13 +117,21 @@ public class SntpClient
            // read the response
            DatagramPacket response = new DatagramPacket(buffer, buffer.length);
            socket.receive(response);
            long responseTicks = SystemClock.elapsedRealtime();
            long responseTime = requestTime + (responseTicks - requestTicks);
            final long responseTicks = SystemClock.elapsedRealtime();
            final long responseTime = requestTime + (responseTicks - requestTicks);

            // extract the results
            long originateTime = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
            long receiveTime = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
            long transmitTime = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);
            final byte leap = (byte) ((buffer[0] >> 6) & 0x3);
            final byte mode = (byte) (buffer[0] & 0x7);
            final int stratum = (int) (buffer[1] & 0xff);
            final long originateTime = readTimeStamp(buffer, ORIGINATE_TIME_OFFSET);
            final long receiveTime = readTimeStamp(buffer, RECEIVE_TIME_OFFSET);
            final long transmitTime = readTimeStamp(buffer, TRANSMIT_TIME_OFFSET);

            /* do sanity check according to RFC */
            // TODO: validate originateTime == requestTime.
            checkValidServerReply(leap, mode, stratum, transmitTime);

            long roundTripTime = responseTicks - requestTicks - (transmitTime - receiveTime);
            // receiveTime = originateTime + transit + skew
            // responseTime = transmitTime + transit - skew
@@ -110,8 +142,10 @@ public class SntpClient
            //             = (transit + skew - transit + skew)/2
            //             = (2 * skew)/2 = skew
            long clockOffset = ((receiveTime - originateTime) + (transmitTime - responseTime))/2;
            // if (false) Log.d(TAG, "round trip: " + roundTripTime + " ms");
            // if (false) Log.d(TAG, "clock offset: " + clockOffset + " ms");
            if (DBG) {
                Log.d(TAG, "round trip: " + roundTripTime + "ms, " +
                        "clock offset: " + clockOffset + "ms");
            }

            // save our results - use the times on this side of the network latency
            // (response rather than request time)
@@ -119,7 +153,7 @@ public class SntpClient
            mNtpTimeReference = responseTicks;
            mRoundTripTime = roundTripTime;
        } catch (Exception e) {
            if (false) Log.d(TAG, "request time failed: " + e);
            if (DBG) Log.d(TAG, "request time failed: " + e);
            return false;
        } finally {
            if (socket != null) {
@@ -158,6 +192,23 @@ public class SntpClient
        return mRoundTripTime;
    }

    private static void checkValidServerReply(
            byte leap, byte mode, int stratum, long transmitTime)
            throws InvalidServerReplyException {
        if (leap == NTP_LEAP_NOSYNC) {
            throw new InvalidServerReplyException("unsynchronized server");
        }
        if ((mode != NTP_MODE_SERVER) && (mode != NTP_MODE_BROADCAST)) {
            throw new InvalidServerReplyException("untrusted mode: " + mode);
        }
        if ((stratum == NTP_STRATUM_DEATH) || (stratum > NTP_STRATUM_MAX)) {
            throw new InvalidServerReplyException("untrusted stratum: " + stratum);
        }
        if (transmitTime == 0) {
            throw new InvalidServerReplyException("zero transmitTime");
        }
    }

    /**
     * Reads an unsigned 32 bit big endian number from the given offset in the buffer.
     */
@@ -183,6 +234,10 @@ public class SntpClient
    private long readTimeStamp(byte[] buffer, int offset) {
        long seconds = read32(buffer, offset);
        long fraction = read32(buffer, offset + 4);
        // Special case: zero means zero.
        if (seconds == 0 && fraction == 0) {
            return 0;
        }
        return ((seconds - OFFSET_1900_TO_1970) * 1000) + ((fraction * 1000L) / 0x100000000L);
    }

@@ -191,6 +246,12 @@ public class SntpClient
     * at the given offset in the buffer.
     */
    private void writeTimeStamp(byte[] buffer, int offset, long time) {
        // Special case: zero means zero.
        if (time == 0) {
            Arrays.fill(buffer, offset, offset + 8, (byte) 0x00);
            return;
        }

        long seconds = time / 1000L;
        long milliseconds = time - seconds * 1000L;
        seconds += OFFSET_1900_TO_1970;
+222 −0
Original line number Diff line number Diff line
/*
 * Copyright (C) 2015 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.net;

import android.net.SntpClient;
import android.util.Log;
import libcore.util.HexEncoding;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.net.InetAddress;
import java.net.SocketException;
import java.util.Arrays;
import junit.framework.TestCase;


public class SntpClientTest extends TestCase {
    private static final String TAG = "SntpClientTest";

    private static final int ORIGINATE_TIME_OFFSET = 24;
    private static final int TRANSMIT_TIME_OFFSET = 40;

    private static final int NTP_MODE_SERVER = 4;
    private static final int NTP_MODE_BROADCAST = 5;

    // From tcpdump (admittedly, an NTPv4 packet):
    //
    // Server, Leap indicator:  (0), Stratum 2 (secondary reference), poll 6 (64s), precision -20
    // Root Delay: 0.005447, Root dispersion: 0.002716, Reference-ID: 221.253.71.41
    //   Reference Timestamp:  3653932102.507969856 (2015/10/15 14:08:22)
    //   Originator Timestamp: 3653932113.576327741 (2015/10/15 14:08:33)
    //   Receive Timestamp:    3653932113.581012725 (2015/10/15 14:08:33)
    //   Transmit Timestamp:   3653932113.581012725 (2015/10/15 14:08:33)
    //     Originator - Receive Timestamp:  +0.004684958
    //     Originator - Transmit Timestamp: +0.004684958
    private static final String WORKING_VERSION4 =
            "240206ec" +
            "00000165" +
            "000000b2" +
            "ddfd4729" +
            "d9ca9446820a5000" +
            "d9ca9451938a3771" +
            "d9ca945194bd3fff" +
            "d9ca945194bd4001";

    private final SntpTestServer mServer = new SntpTestServer();
    private final SntpClient mClient = new SntpClient();

    public void testBasicWorkingSntpClientQuery() throws Exception {
        mServer.setServerReply(HexEncoding.decode(WORKING_VERSION4.toCharArray(), false));
        assertTrue(mClient.requestTime(mServer.getAddress(), mServer.getPort(), 500));
        assertEquals(1, mServer.numRequestsReceived());
        assertEquals(1, mServer.numRepliesSent());
    }

    public void testDnsResolutionFailure() throws Exception {
        assertFalse(mClient.requestTime("ntp.server.doesnotexist.example", 5000));
    }

    public void testTimeoutFailure() throws Exception {
        mServer.clearServerReply();
        assertFalse(mClient.requestTime(mServer.getAddress(), mServer.getPort(), 500));
        assertEquals(1, mServer.numRequestsReceived());
        assertEquals(0, mServer.numRepliesSent());
    }

    public void testIgnoreLeapNoSync() throws Exception {
        final byte[] reply = HexEncoding.decode(WORKING_VERSION4.toCharArray(), false);
        reply[0] |= (byte) 0xc0;
        mServer.setServerReply(reply);
        assertFalse(mClient.requestTime(mServer.getAddress(), mServer.getPort(), 500));
        assertEquals(1, mServer.numRequestsReceived());
        assertEquals(1, mServer.numRepliesSent());
    }

    public void testAcceptOnlyServerAndBroadcastModes() throws Exception {
        final byte[] reply = HexEncoding.decode(WORKING_VERSION4.toCharArray(), false);
        for (int i = 0; i <= 7; i++) {
            final String logMsg = "mode: " + i;
            reply[0] &= (byte) 0xf8;
            reply[0] |= (byte) i;
            mServer.setServerReply(reply);
            final boolean rval = mClient.requestTime(mServer.getAddress(), mServer.getPort(), 500);
            switch (i) {
                case NTP_MODE_SERVER:
                case NTP_MODE_BROADCAST:
                    assertTrue(logMsg, rval);
                    break;
                default:
                    assertFalse(logMsg, rval);
                    break;
            }
            assertEquals(logMsg, 1, mServer.numRequestsReceived());
            assertEquals(logMsg, 1, mServer.numRepliesSent());
        }
    }

    public void testAcceptableStrataOnly() throws Exception {
        final int STRATUM_MIN = 1;
        final int STRATUM_MAX = 15;

        final byte[] reply = HexEncoding.decode(WORKING_VERSION4.toCharArray(), false);
        for (int i = 0; i < 256; i++) {
            final String logMsg = "stratum: " + i;
            reply[1] = (byte) i;
            mServer.setServerReply(reply);
            final boolean rval = mClient.requestTime(mServer.getAddress(), mServer.getPort(), 500);
            if (STRATUM_MIN <= i && i <= STRATUM_MAX) {
                assertTrue(logMsg, rval);
            } else {
                assertFalse(logMsg, rval);
            }
            assertEquals(logMsg, 1, mServer.numRequestsReceived());
            assertEquals(logMsg, 1, mServer.numRepliesSent());
        }
    }

    public void testZeroTransmitTime() throws Exception {
        final byte[] reply = HexEncoding.decode(WORKING_VERSION4.toCharArray(), false);
        Arrays.fill(reply, TRANSMIT_TIME_OFFSET, TRANSMIT_TIME_OFFSET + 8, (byte) 0x00);
        mServer.setServerReply(reply);
        assertFalse(mClient.requestTime(mServer.getAddress(), mServer.getPort(), 500));
        assertEquals(1, mServer.numRequestsReceived());
        assertEquals(1, mServer.numRepliesSent());
    }


    private static class SntpTestServer {
        private final Object mLock = new Object();
        private final DatagramSocket mSocket;
        private final InetAddress mAddress;
        private final int mPort;
        private byte[] mReply;
        private int mRcvd;
        private int mSent;
        private Thread mListeningThread;

        public SntpTestServer() {
            mSocket = makeSocket();
            mAddress = mSocket.getLocalAddress();
            mPort = mSocket.getLocalPort();
            Log.d(TAG, "testing server listening on (" + mAddress + ", " + mPort + ")");

            mListeningThread = new Thread() {
                public void run() {
                    while (true) {
                        byte[] buffer = new byte[512];
                        DatagramPacket ntpMsg = new DatagramPacket(buffer, buffer.length);
                        try {
                            mSocket.receive(ntpMsg);
                        } catch (IOException e) {
                            Log.e(TAG, "datagram receive error: " + e);
                            break;
                        }
                        synchronized (mLock) {
                            mRcvd++;
                            if (mReply == null) { continue; }
                            // Copy transmit timestamp into originate timestamp.
                            // TODO: bounds checking.
                            System.arraycopy(ntpMsg.getData(), TRANSMIT_TIME_OFFSET,
                                             mReply, ORIGINATE_TIME_OFFSET, 8);
                            ntpMsg.setData(mReply);
                            ntpMsg.setLength(mReply.length);
                            try {
                                mSocket.send(ntpMsg);
                            } catch (IOException e) {
                                Log.e(TAG, "datagram send error: " + e);
                                break;
                            }
                            mSent++;
                        }
                    }
                    mSocket.close();
                }
            };
            mListeningThread.start();
        }

        private DatagramSocket makeSocket() {
            DatagramSocket socket;
            try {
                socket = new DatagramSocket(0, InetAddress.getLoopbackAddress());
            } catch (SocketException e) {
                Log.e(TAG, "Failed to create test server socket: " + e);
                return null;
            }
            return socket;
        }

        public void clearServerReply() {
            setServerReply(null);
        }

        public void setServerReply(byte[] reply) {
            synchronized (mLock) {
                mReply = reply;
                mRcvd = 0;
                mSent = 0;
            }
        }

        public InetAddress getAddress() { return mAddress; }
        public int getPort() { return mPort; }
        public int numRequestsReceived() { synchronized (mLock) { return mRcvd; } }
        public int numRepliesSent() { synchronized (mLock) { return mSent; } }
    }
}