﻿# -*- coding:utf-8 -*-
#
# W600 flash download script
# Copyright (c) 2018 Winner Micro Electronic Design Co., Ltd.
# All rights reserved.
# 
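# Python 3 (originally targeted at 3.4)
#
# Runtime dependencies:
#   pip install pyserial
#   pip install pyprind
#   pip install xmodem
#   pip install enum34    # enum backport, only needed on Python 2.7
#
# Optional, to build a standalone executable:
#   pip install PyInstaller
#   pyinstaller -F download.py -i Downloads_folder_512px.ico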
#

import sys, getopt
import os
from io import UnsupportedOperation
# Decide whether pyprind progress bars are used: either PYCHARM_HOSTED is
# set in the environment or stdout is attached to a terminal.
try:
    supported = ('PYCHARM_HOSTED' in os.environ
                 or os.isatty(sys.stdout.fileno()))
except UnsupportedOperation:
    # e.g. the IPython notebook stream: "IOStream has no fileno."
    supported = True

# 1 -> draw pyprind progress bars, 0 -> plain output only.
USE_PYPRIND = 1 if supported else 0

import serial
import struct
import platform
if USE_PYPRIND == 1:
    import pyprind
import time
from wmxmodem import XMODEM1k
from wmxmodem import XMODEM
import threading
from time import ctime,sleep
from enum import IntEnum
import threading

# Send with XMODEM-1K (1024-byte blocks) instead of classic 128-byte XMODEM.
ENABLE_XMODEM_1K = True

# Command-line defaults; get_args() overwrites these.
COM_NAME = "COM1"
IMG_FILE = "WM_W600_GZ.img"
reset_success = 0
BAUD_SPEED = 115200

# Baud rates the restart logic cycles through, as strings.
SUPPORT_BAUDRATE = (
    "9600", "19200", "38400", "57600", "115200", "230400",
    "460800", "921600", "1000000", "1500000", "2000000",
)
# Index into SUPPORT_BAUDRATE used for the firmware transfer (4 -> "115200").
DOWNLOAD_BAUDRATE_IDX = 4
# Set to True by read_char_from_serial_thread() once the MCU reported its MAC.
get_MAC_from_MCU = False

# Tool version, printed as "download <major>.<minor>.<revision>".
Major_version = 0
Minor_version = 3
Reversion_num = 2

class m_err(IntEnum):
    """Process exit codes used by this script (0 = success, negative = failure)."""

    success = 0
    err_serial_obj_init = -1    # WMDownload() could not open the serial port
    err_serial_open = -2        # WMDownload.open() failed
    err_serial_read = -3        # WMDownload.getc() hit a serial error
    err_restart_mcu = -4        # restart_w600() never got the MCU ready
    err_burn_firmware = -5      # the XMODEM transfer failed
    err_traverse_br_max = -6    # restart_mcu_thread() gave up after its retry passes

class WMDownload(object):
    """Serial-port wrapper used to download an image to a W600 module.

    Holds the pyserial port (``self.ser``) and the path of the image to
    send (``self.image``). When USE_PYPRIND is 1 it also holds a pyprind
    progress bar (``self.bar_user``), which putc_bar() advances on every
    call.
    """

    if platform.system() == 'Windows':
        DEFAULT_PORT = "COM1"
    else:
        DEFAULT_PORT = "/dev/ttyUSB0"
    DEFAULT_BAUD = 115200
    DEFAULT_TIMEOUT = 0.3
    DEFAULT_IMAGE = "../Bin/WM_W600_GZ.img"

    def __init__(self, port=DEFAULT_PORT, baud=DEFAULT_BAUD, timeout=DEFAULT_TIMEOUT, image=DEFAULT_IMAGE):
        """Open *port* at *baud* with read *timeout* and remember *image*.

        RTS is set to False right after opening. If the port cannot be
        opened, the process exits with ``m_err.err_serial_obj_init``.
        """
        try:
            self.image = image
            self.ser = serial.Serial(port, baud, timeout=timeout)
            self.ser.rts = False
            if USE_PYPRIND == 1:
                self.bar_user = self._new_progress_bar()
        except serial.serialutil.SerialException:
            print('[E] Serial %s is used by other or INIT error' % port)
            sys.stdout.flush()
            os._exit(m_err['err_serial_obj_init'])

    def _new_progress_bar(self):
        """Build a pyprind bar for ``self.image``: one tick per XMODEM block, plus 2."""
        block_size = 1024 if ENABLE_XMODEM_1K else 128
        return pyprind.ProgBar(os.stat(self.image).st_size / block_size + 2, stream=1)

    def update_img_to_FLS(self):
        """Switch ``self.image`` from the "gz.img" name to the "FLS" name.

        NOTE(review): the replacement is case-sensitive. The default names
        ("WM_W600_GZ.img") are left unchanged, and "x_gz.img" becomes
        "x_FLS" (no dot). Confirm the expected FLS file name.
        """
        self.image = self.image.replace("gz.img", "FLS")

    def rtscts(self, rtscts=True):
        """Enable/disable RTS/CTS flow control (also used to pulse a restart)."""
        self.ser.rtscts = rtscts

    def image_path(self):
        """Return the path of the image that will be sent."""
        return self.image

    def set_port_baudrate(self, baud):
        """Change the port's baud rate; *baud* may be an int or numeric string."""
        # pyserial's ``baudrate`` setter converts with int() and already
        # reconfigures an open port, so no call to the private
        # _reconfigure_port() is needed.
        self.ser.baudrate = baud

    def set_timeout(self, timeout):
        """Set the port's default read timeout in seconds."""
        self.ser.timeout = timeout

    def getc(self, size, timeout=1):
        """Read up to *size* bytes, waiting at most *timeout* seconds.

        The port's previous timeout is restored afterwards, even if the
        read fails. On a serial error the process exits with
        ``m_err.err_serial_read``.
        """
        try:
            saved_timeout = self.ser.timeout
            self.ser.timeout = timeout
            try:
                return self.ser.read(size)
            finally:
                self.ser.timeout = saved_timeout
        except serial.serialutil.SerialException:
            print("[E] OFFline Serial %s" % self.ser.port)
            sys.stdout.flush()
            os._exit(m_err["err_serial_read"])

    def putc(self, data, timeout=1):
        """Write *data* and wait for it to drain; a write timeout is ignored."""
        self.ser.timeout = timeout
        try:
            self.ser.write(data)
            self.ser.flush()
        except serial.serialutil.SerialTimeoutException:
            # Best effort: the callers resend periodically.
            return

    def rst_putc(self, data, timeout=1):
        """Write *data*; serial errors are printed and ignored (port may be closed)."""
        self.ser.timeout = timeout
        try:
            self.ser.write(data)
        except serial.serialutil.SerialException as err:
            print('[ ', sys._getframe().f_lineno, ' ]', err)
            return

    def rst_flushinput(self):
        """Discard pending input; ignores errors (the port may be closed)."""
        try:
            self.ser.flushInput()
        except serial.serialutil.SerialException:
            return

    def putc_bar(self, data, timeout=1):
        """Write *data*, advancing the progress bar; returns bytes written."""
        self.ser.timeout = timeout
        if USE_PYPRIND == 1:
            self.bar_user.update()
        return self.ser.write(data)

    def reset_bar(self):
        """Recreate the progress bar for the current ``self.image``.

        Presumably used by the XMODEM sender when it restarts a transfer
        (e.g. after update_img_to_FLS) -- confirm against wmxmodem.
        """
        if USE_PYPRIND == 1:
            self.bar_user = self._new_progress_bar()

    def open(self):
        """(Re)open the port; exits with ``m_err.err_serial_open`` on failure."""
        try:
            self.ser.open()
        except serial.serialutil.SerialException:
            print("[E] Open Serial %s error" % self.ser.port)
            sys.stdout.flush()
            os._exit(m_err["err_serial_open"])

    def close(self):
        """Flush pending output, discard pending input and close the port."""
        self.ser.flush()
        self.ser.flushInput()
        self.ser.close()

    def info(self):
        """Return ``(port_name, baudrate)`` of the underlying port."""
        return (self.ser.port, self.ser.baudrate)

def print_progress_bar(current, total, width=30):
    """Draw a one-line text progress bar on stdout, ending with a carriage return.

    Output looks like ``" 42% [############                  ] 100%\\r"``:
    the current percentage, a bar exactly *width* characters wide filled
    in proportion to ``current / total``, and a fixed "100%" label.

    Integer arithmetic keeps the bar exactly *width* characters wide; the
    float division previously used rounded both parts up. *current* is
    clamped to ``[0, total]``. A non-positive *total* is drawn as complete
    instead of raising ZeroDivisionError.
    """
    if total <= 0:
        current, total = 1, 1
    current = max(0, min(current, total))
    width = int(width)
    filled = int(width * current // total)
    level = int(current * 100 // total)
    sys.stdout.write("%3d%% [%s%s] 100%%\r" % (level, '#' * filled, ' ' * (width - filled)))
    sys.stdout.flush()
    
def print_point(count):
    """Print the *count*-th (1-based) progress dot without a trailing newline.

    Dots are grouped visually:
    - 20 dots per line: a new line starts before dots 21, 41, ...
    - three spaces precede dots 1, 11, 31, ...
    - one space precedes dots 6, 16, 26, ...
    """
    position = count - 1
    if position % 20 == 0 and count > 1:
        separator = '\n'
    elif position % 10 == 0:
        separator = '   '
    elif position % 5 == 0:
        separator = ' '
    else:
        separator = ''
    print(separator + '. ', end='')
    sys.stdout.flush()
    
def wait_for_secs(seconds, print_flag=1):
    """Block for *seconds* seconds, printing one progress dot per second.

    When *print_flag* is 1, a "please wait" banner is printed first. A
    final newline terminates the row of dots.
    """
    if print_flag == 1:
        print("Please wait for about %s seconds before module restart..." % seconds)
    dot = 1
    while dot <= seconds:
        print_point(dot)
        time.sleep(1)
        dot += 1
    print()
    
def print_version():
    """Print the tool name and its major.minor.revision version."""
    version = (Major_version, Minor_version, Reversion_num)
    print("download %d.%d.%d" % version)
    
def usage():
    """Print the version banner and the command-line help.

    The defaults listed are the module-level values (COM_NAME, BAUD_SPEED,
    IMG_FILE) that main() passes to WMDownload when an option is omitted.
    They are the same on every platform.
    """
    print_version()
    print('USAGE:')
    print('win:python download.py [-c|--console= COM] [-b|--baud-speed= BAUD_SPEED] [-f|--image-file= image]')
    print('Linux:python3 download.py [-c|--console= COM] [-b|--baud-speed= BAUD_SPEED] [-f|--image-file= image]')
    print()
    print('\t-c|--console=    COM         the console used by download function.')
    print('\t                             default: "COM1" (on Linux pass e.g. "/dev/ttyUSB0").')
    print('\t-b|--baud-speed= BAUD_SPEED  the baud speed selected by user.')
    print('\t                             default: 115200')
    print('\t-f|--image-file= IMAGE       the image file used by download function transmitted to the DUT.')
    print('\t                             default: "WM_W600_GZ.img"')
    print('\t-h|--help                    get this help information')
    print('\t-V|--version                 get version number')
    print()
    
def get_args(argv=None):
    """Parse command-line options into the module-level settings.

    Args:
        argv: option list to parse; defaults to ``sys.argv[1:]``.

    Sets:
    - COM_NAME from -c/--console;
    - BAUD_SPEED from -b/--baud-speed (kept as the given string);
    - IMG_FILE from -f/--image-file.

    -h/--help prints the usage and -V/--version prints the version; both
    then exit with ``m_err.success``. An unknown option or a missing
    option argument prints the error plus the usage and exits with
    status 2, instead of dumping a traceback.
    """
    global COM_NAME, BAUD_SPEED, IMG_FILE
    if argv is None:
        argv = sys.argv[1:]
    try:
        opts, _ = getopt.getopt(argv, "c:b:f:hV",
                                ["console=", "baud-speed=", "image-file=", "help", "version"])
    except getopt.GetoptError as err:
        print("[E] %s" % err)
        usage()
        sys.stdout.flush()
        sys.exit(2)
    for op, value in opts:
        if op in ("-c", "--console"):
            COM_NAME = value
        elif op in ("-b", "--baud-speed"):
            BAUD_SPEED = value
        elif op in ("-f", "--image-file"):
            IMG_FILE = value
        elif op in ("-h", "--help"):
            usage()
            sys.stdout.flush()
            sys.exit(m_err['success'])
        elif op in ("-V", "--version"):
            print_version()
            sys.stdout.flush()
            sys.exit(m_err['success'])
              
def get_baud_first_index(baud=None, supported_baudrates=None):
    """Return the index of *baud* in *supported_baudrates*, or 0 if absent.

    Args:
        baud: baud rate to look up (int or string); defaults to BAUD_SPEED.
        supported_baudrates: candidate rates; defaults to SUPPORT_BAUDRATE.

    Values are compared as strings, so 115200 and "115200" match. The
    previous hand-unrolled chain indexed positions 11-14 of an 11-entry
    tuple and raised IndexError for any unsupported rate. It now falls
    back to index 0 as intended.
    """
    if baud is None:
        baud = BAUD_SPEED
    if supported_baudrates is None:
        supported_baudrates = SUPPORT_BAUDRATE
    target = str(baud)
    for index, candidate in enumerate(supported_baudrates):
        if str(candidate) == target:
            return index
    return 0

# ---- State shared by the worker threads and restart_w600() ----
# Set to True to stop the restart, ESC and serial-read worker loops.
CANCEL_THREAD = False
# Index into SUPPORT_BAUDRATE (taken modulo its length) tried by restart_mcu_thread().
BR_INDEX = 0
# Re-entrant lock serialising use of the serial port between threads.
MUTEX_SERIAL = threading.RLock()
# Set by read_char_from_serial_thread() once the MCU answered the MAC query;
# restart_w600() waits for it.
WAITING_4_START_XMODEM = False
# Placeholder; start_threads() replaces it with the WMDownload instance.
SERIAL_OBJ = 1
# Set by restart_mcu_thread() after "+OK" is seen, which pins the baud rate.
STOP_TRAVERSE_BAUDRATE_LIST = False
# True once restart_mcu_thread() got "+OK" in the current round.
GET_OK = False
# True on rounds where restart_mcu_thread() toggled rtscts (apparently to
# pulse the module's reset line -- confirm against the board wiring).
RESET_RTS_flag = False

def get_3_C_symbol(serial, count, timeout):
    """Return True once enough b'C' bytes have been read from *serial*.

    Reads at most *count* single bytes, each waiting up to *timeout*
    seconds. Succeeds as soon as 3 'C' bytes (or *count* of them, if
    that is smaller) have been seen; otherwise returns False.
    """
    hits = 0
    reads = 0
    while reads < count:
        if serial.getc(1, timeout) == b'C':
            hits += 1
        if hits in (3, count):
            sys.stdout.flush()
            return True
        reads += 1
    return False

def try_to_get_C_on_higher_baudrate(serial, count):
    """Probe for 'C' bytes after a baud-rate switch (0.6 s per byte read)."""
    per_byte_timeout = 0.6
    return get_3_C_symbol(serial, count, per_byte_timeout)

def try_to_reset_with_higher_baudrate(serial):
    """Ask the device to use a faster baud rate for the firmware transfer.

    The target is SUPPORT_BAUDRATE[DOWNLOAD_BAUDRATE_IDX]. Only indices 6,
    7, 8 and 10 (460800, 921600, 1M, 2M) have a switch command; any other
    index keeps 115200. If neither "+OK" (GET_OK) nor a MAC reply
    (get_MAC_from_MCU) was seen, a warning is printed and 115200 is kept.

    After switching, the host port waits for 3 'C' bytes at the new rate.
    If they do not arrive, the device is told to return to 115200.

    Returns:
        (estimated_seconds, baudrate) for the upcoming transfer.
    """
    # DOWNLOAD_BAUDRATE_IDX -> (switch command hex, estimated seconds, baud)
    switch_commands = {
        6: ('210a0007003100000000080700', 12, 460800),    # earlier figure: 9.41
        7: ('210a005d503100000000100e00', 9, 921600),     # earlier figure: 6.97
        8: ('210a005e3d3100000040420f00', 9, 1000000),    # earlier figure: 6.76
        10: ('210a00ef2a3100000080841e00', 8, 2000000),   # earlier figure: 5.63
    }
    time.sleep(0.2)
    if False == GET_OK and False == get_MAC_from_MCU:
        print("The program did not find the right baud rate used by the DevBoard.")
        sys.stdout.flush()
        return 30, 115200
    if DOWNLOAD_BAUDRATE_IDX not in switch_commands:
        return 30, 115200

    command_hex, spend_time, local_baudrate = switch_commands[DOWNLOAD_BAUDRATE_IDX]
    serial.putc(bytes.fromhex(command_hex))
    time.sleep(0.2)
    serial.set_port_baudrate(SUPPORT_BAUDRATE[DOWNLOAD_BAUDRATE_IDX])
    time.sleep(0.2)
    serial.ser.flushInput()
    time.sleep(0.3)

    if not try_to_get_C_on_higher_baudrate(serial, 3):
        # No 'C' at the new rate: command the device back to 115200.
        serial.putc(bytes.fromhex('210a00974b3100000000c20100'))
        serial.set_port_baudrate(115200)
        print("Use 115200 to send the firmware.")
        sys.stdout.flush()
        return 30, 115200
    return spend_time, local_baudrate

def send_esc_thread():
    """Worker: keep sending ESC bytes at 115200 until CANCEL_THREAD is set.

    Each round holds MUTEX_SERIAL while it:
    - forces the port back to 115200 if another thread changed it;
    - writes three ESC (0x1B) bytes.

    The lock is then released for 10 ms so the other workers can use the
    port.
    """
    serial = SERIAL_OBJ
    esc = struct.pack('<B', 27)    # ESC 0x1B
    print("Start send ESC thread.")
    sys.stdout.flush()
    while not CANCEL_THREAD:
        # ``with`` releases the lock even if a serial call raises. A leaked
        # RLock would block the other worker threads forever.
        with MUTEX_SERIAL:
            _, baudrate = serial.info()
            if baudrate != 115200:
                serial.set_port_baudrate(115200)
            for _ in range(3):
                serial.putc(esc)
        time.sleep(0.01)
        
def restart_mcu_thread():
    """Worker: repeatedly ask the MCU to reset until it answers "+OK".

    Each round holds MUTEX_SERIAL while it:
      * sets the port to the next rate in SUPPORT_BAUDRATE, starting at the
        index of BAUD_SPEED. Once "+OK" has been seen, it instead uses the
        rate "+OK" arrived at (``stop_baudrate``).
      * flushes the port and sends "at+z\\r\\n", sending it a second time
        if get_ok_symbol() did not see "+OK".
      * on "+OK": switches to 115200 and sends 150 ESC (0x1B) bytes, then
        sets GET_OK and STOP_TRAVERSE_BAUDRATE_LIST.
      * on round 4 and every third round after that: toggles rtscts and
        sets RESET_RTS_flag for that round.
    After releasing the lock it sleeps at least 1 s.

    It prints a message at the start of the first, second and third pass
    over the rate list. Once ``cycle_count + count_increase`` reaches
    three passes, it closes the port and exits the process with
    ``m_err.err_traverse_br_max``. It runs until CANCEL_THREAD becomes
    True.

    NOTE(review): MUTEX_SERIAL is acquired and released manually, so an
    exception during a round leaves it held and blocks the other workers.
    """
    serial = SERIAL_OBJ;
    print("Start restart thread.");
    sys.stdout.flush();
    #time.sleep(0.01);
    sleep_count = 0;    # unused
    global BAUD_SPEED
    global GET_OK
    global BR_INDEX
    # Start traversing SUPPORT_BAUDRATE at the user-selected rate (0 if unknown).
    idx = get_baud_first_index();
    prev_br_index = -1;
    print_transver_br_list_flag = False;
    # Rounds since the last rtscts toggle.
    send_count = 0;
    # Rounds completed; drives the "try cycle" messages and the give-up limit.
    cycle_count = 0;
    global STOP_TRAVERSE_BAUDRATE_LIST;
    # Rate to stay on once "+OK" stopped the traversal.
    stop_baudrate = "115200";
    global RESET_RTS_flag;
    # Becomes 5 once traversal stops, so the give-up limit is reached sooner.
    count_increase = 0;
    count_increase_flag = False;
    #first_flush_serial_input_flag = False;
    
    while CANCEL_THREAD == False:
        global MUTEX_SERIAL
        if cycle_count == 0:
            print("First try cycle...");
            sys.stdout.flush();
        if cycle_count + count_increase == len(SUPPORT_BAUDRATE):
            print("Second try cycle...");
            sys.stdout.flush();
        elif cycle_count + count_increase == len(SUPPORT_BAUDRATE) * 2:
            print("Third try cycle...");
            sys.stdout.flush();
        elif cycle_count + count_increase >= len(SUPPORT_BAUDRATE) * 3:
            # Give up after three passes over the baud-rate list.
            print("!!! PLEASE BURN FIRMWARE MANUALLY !!!");
            sys.stdout.flush();
            serial.close();
            os._exit(m_err['err_traverse_br_max']);
            break;    # unreachable: os._exit() does not return
        
        MUTEX_SERIAL.acquire();
        BR_INDEX = idx;
        GET_OK = False;
        br = SUPPORT_BAUDRATE[ BR_INDEX % len(SUPPORT_BAUDRATE) ];
        # Announce the rate only when it changed since the previous round.
        if prev_br_index != BR_INDEX:
            if str(br) == str(BAUD_SPEED):
                print("Try to open Serial with baud speed: %s..." % br);
            else:
                print("Try to re-open Serial with baud speed: %s..." % br);
                
            # NOTE(review): once traversal stops, idx no longer changes, so this
            # branch is not re-entered and the message below appears never to print.
            if (False == print_transver_br_list_flag and True == STOP_TRAVERSE_BAUDRATE_LIST):
                print("Get printable symbol from the MCU. Stop Traverse baudrate lists.");
                sys.stdout.flush();
                print_transver_br_list_flag = True;
            prev_br_index = BR_INDEX;
            sys.stdout.flush();
            
        # Put the host port on this round's rate (next list entry, or the pinned rate).
        if (False == STOP_TRAVERSE_BAUDRATE_LIST):
            (port, baudr) = serial.info();
            if (str(baudr) != br):
                serial.set_port_baudrate(SUPPORT_BAUDRATE[ BR_INDEX % len(SUPPORT_BAUDRATE) ]);
                time.sleep(0.3);
        else:
            (port, baudr) = serial.info();
            if (str(baudr) != stop_baudrate):
                serial.set_port_baudrate(stop_baudrate);
                time.sleep(0.1);

        #if (False == first_flush_serial_input_flag):
        serial.ser.flushInput();
        serial.ser.flush();
        #    first_flush_serial_input_flag = True;
        # Send "at+z\r\n".
        serial.rst_putc(struct.pack('<B', 97))          # 'a'   0x61
        serial.rst_putc(struct.pack('<B', 116))         # 't'   0x74
        serial.rst_putc(struct.pack('<B', 43))          # '+'   0x2B
        serial.rst_putc(struct.pack('<B', 122))         # 'z'   0x7A
        serial.rst_putc(struct.pack('<B', 13))          # '\r'  0x0D
        serial.rst_putc(struct.pack('<B', 10))          # '\n'  0x0A
        # Each ESC burst is max_esc_count * 3 = 150 bytes.
        max_esc_count = 50;

        if(GET_OK == False):
            ret = get_ok_symbol(serial);
            if (True == ret):
                # "+OK" received: go to 115200, send the ESC burst and pin this rate.
                esc_count = 0;
                (lport, lbaudr) = serial.info();
                if (str(lbaudr) != "115200"):
                    serial.set_port_baudrate("115200");
                    time.sleep(0.01);
                while (esc_count < max_esc_count):
                    serial.putc(struct.pack('<B', 27))    # ESC     0x1B
                    serial.putc(struct.pack('<B', 27))    # ESC     0x1B
                    serial.putc(struct.pack('<B', 27))    # ESC     0x1B
                    esc_count += 1;
                STOP_TRAVERSE_BAUDRATE_LIST = True;
                stop_baudrate = SUPPORT_BAUDRATE[ BR_INDEX % len(SUPPORT_BAUDRATE) ];
                GET_OK = True;
        
        # Second attempt in the same round: resend "at+z\r\n" and look for "+OK" again.
        if(GET_OK == False):
            # ignore first_flush_serial_input_flag
            serial.ser.flushInput();
            serial.ser.flush();
            serial.rst_putc(struct.pack('<B', 97))          # 'a'   0x61
            serial.rst_putc(struct.pack('<B', 116))         # 't'   0x74
            serial.rst_putc(struct.pack('<B', 43))          # '+'   0x2B
            serial.rst_putc(struct.pack('<B', 122))         # 'z'   0x7A
            serial.rst_putc(struct.pack('<B', 13))          # '\r'  0x0D
            serial.rst_putc(struct.pack('<B', 10))          # '\n'  0x0A

            ret = get_ok_symbol(serial);
            if (True == ret):
                esc_count = 0;
                (llport, llbaudr) = serial.info();
                if (str(llbaudr) != "115200"):
                    serial.set_port_baudrate("115200");
                    time.sleep(0.01);
                while (esc_count < max_esc_count):
                    serial.putc(struct.pack('<B', 27))    # ESC     0x1B
                    serial.putc(struct.pack('<B', 27))    # ESC     0x1B
                    serial.putc(struct.pack('<B', 27))    # ESC     0x1B
                    esc_count += 1;
                STOP_TRAVERSE_BAUDRATE_LIST = True;
                stop_baudrate = SUPPORT_BAUDRATE[ BR_INDEX % len(SUPPORT_BAUDRATE) ];
                GET_OK = True;

        # Advance to the next rate unless "+OK" pinned the current one.
        if (False == STOP_TRAVERSE_BAUDRATE_LIST):
            idx = idx + 1;
        else:
            print(".", end='');
            if (False == count_increase_flag):
                count_increase_flag = True;
                count_increase = 5;
            sys.stdout.flush();

        # Round 4 and every third round after that: toggle rtscts (reset pulse).
        if (send_count >= 3):
            serial.rtscts(True);
            time.sleep(0.1);
            serial.rtscts(False);
            time.sleep(0.1);
            RESET_RTS_flag = True;
            send_count = 0;
        else:
            RESET_RTS_flag = False;
        send_count += 1;
        
        MUTEX_SERIAL.release();

        if (True == RESET_RTS_flag):
            time.sleep(0.3);
        #if (str(baudr) >= str(115200)):
        time.sleep(1);   # 115200
        
        cycle_count = cycle_count + 1;

# Count limits for 'C' / 'P' handshake symbols (not referenced elsewhere in this file).
MAX_C_COUNT = 1
MAX_P_COUNT = 1

def get_3_P_symbol(serial, timeout, count):
    """Read exactly one byte from *serial* and compare the 'P' tally to *count*.

    Despite the name, only a single byte is read (waiting up to *timeout*
    seconds). The number of 'P' bytes seen is therefore 0 or 1, and the
    result is True when that number equals *count*.
    """
    hits = 1 if serial.getc(1, timeout) == b'P' else 0
    if hits == count:
        sys.stdout.flush()
        return True
    return False

def get_ok_symbol(serial):
    """Return True if the next three bytes on *serial* are b'+', b'O', b'K'.

    Stops reading at the first mismatch. The first two bytes wait up to
    0.1 s each; the last waits up to 0.01 s.
    """
    expected_sequence = ((b'+', 0.1), (b'O', 0.1), (b'K', 0.01))
    for expected, wait in expected_sequence:
        if serial.getc(1, wait) != expected:
            return False
    return True

def read_char_from_serial_thread():
    global CANCEL_THREAD;
    global STOP_TRAVERSE_BAUDRATE_LIST;
    global WAITING_4_START_XMODEM;
    global MUTEX_SERIAL;
    global GET_OK;
    global get_MAC_from_MCU;
    continue_C_count = 0;
    continue_P_count = 0;
    max_continue_C_count = 3;
    max_coutinue_P_count = 3;
    
    serial = SERIAL_OBJ;
    print("Start serial read thread.");
    sys.stdout.flush();
    while (True):
        c = b'';
        #if MUTEX_SERIAL.acquire(1):
        MUTEX_SERIAL.acquire();

        (port, baudr) = serial.info();     
        if (baudr != 115200):
            serial.set_port_baudrate(115200);
        c = serial.getc(1, 0.001);
        if (b'C' == c or b'P' == c):
            serial.ser.flushInput();
            serial.putc(bytes.fromhex('210600ea2d38000000'))    # Get MAC address from ROM or SecBoot
            mac_bytes = serial.getc(20, 0.5);
            mac_str = mac_bytes.decode(encoding='utf-8');
            #print("mac: %r" % mac_str);
            #sys.stdout.flush();
            Sec_mac_pos = mac_str.find("MAC:");
            ROM_mac_pos = mac_str.find("Mac:");
            if (-1 != Sec_mac_pos or -1 != ROM_mac_pos):
                if (-1 != Sec_mac_pos):
                    mac_pos = Sec_mac_pos;
                elif (-1 != ROM_mac_pos):
                    mac_pos = ROM_mac_pos;
                mac_pos += 4;
                mac_v_index = 0;
                success_count = 0;
                while (mac_v_index <= 12):
                    v = mac_str[mac_pos + mac_v_index : mac_pos + mac_v_index + 1];
                    #print("v: %r" % v);
                    if ( 'A' <= v and v <= 'F' ) or ('0' <= v and v <= '9'):
                        success_count += 1;
                        #print("success_count %r" % success_count);
                    mac_v_index += 1;
                if (success_count == 12):
                    if (-1 != Sec_mac_pos):
                        print("Get MAC Address from SecBoot phase.");
                    elif (-1 != ROM_mac_pos):
                        print("Get MAC Address from ROM phase.");
                    get_MAC_from_MCU = True;
                    sys.stdout.flush();
                    CANCEL_THREAD = True;
                    WAITING_4_START_XMODEM = True;
        MUTEX_SERIAL.release();
        
        if (CANCEL_THREAD == True):
            break;
        
        time.sleep(0.01);

def start_threads(serial):
    """Publish *serial* as SERIAL_OBJ and start the three daemon worker threads.

    Start order: send_esc_thread, read_char_from_serial_thread,
    restart_mcu_thread. They are daemon threads, so they do not keep the
    process alive when the main thread finishes.
    """
    global SERIAL_OBJ
    SERIAL_OBJ = serial

    workers = (send_esc_thread, read_char_from_serial_thread, restart_mcu_thread)
    for target in workers:
        # Thread.setDaemon() is deprecated since Python 3.10; use daemon=.
        threading.Thread(target=target, daemon=True).start()
        #time.sleep(0.05);

def get_max_sleep_seconds_via_baudrate(index):
    """Return the wait time in seconds for the baud-rate slot *index*.

    *index* is taken modulo len(SUPPORT_BAUDRATE). Every rate currently
    maps to 1 second; longer waits for the faster rates were commented
    out.
    """
    seconds_by_index = {
        0: 1,   # 9600
        1: 1,   # 19200
        2: 1,   # 38400
        3: 1,   # 57600
        4: 1,   # 115200
    }
    return seconds_by_index.get(index % len(SUPPORT_BAUDRATE), 1)


def restart_w600(serial):
    """Start the worker threads and wait until the MCU is ready for XMODEM.

    Polls once per second until read_char_from_serial_thread() sets
    WAITING_4_START_XMODEM, then reads single bytes from *serial*:

    * b'P': switch the image to the FLS variant and return True.
    * b'C': choose the download baud rate, try to raise it via
      try_to_reset_with_higher_baudrate() and return True.
    * anything else (including a 30 ms read timeout): force 115200 and
      keep reading. After about 200 such reads (~6 s if each read times
      out), print a warning and return False.
    """
    global DOWNLOAD_BAUDRATE_IDX
    sys.stdout.flush()
    read_timeout = 0.03
    max_symbol_spends_time = 5 + 1
    max_misses = max_symbol_spends_time / read_timeout

    start_threads(serial)

    # Wait for the reader thread to see the MCU answer the MAC query.
    while not WAITING_4_START_XMODEM:
        time.sleep(1)

    misses = 0
    while True:
        c = serial.getc(1, read_timeout)
        if c == b'P':
            # Per the original notes, 'P' means there is no secboot, so the
            # FLS image is uploaded instead.
            print("Use FLS file to upload into the MCU.")
            sys.stdout.flush()
            serial.update_img_to_FLS()
            return True
        if c == b'C':
            if RESET_RTS_flag:
                print("Maybe restart the device via RST pin.")
                sys.stdout.flush()
                # After an RST-pin restart the device runs secboot at 115200,
                # so the port is already at the right rate.
                DOWNLOAD_BAUDRATE_IDX = 4
            else:
                DOWNLOAD_BAUDRATE_IDX = BR_INDEX % len(SUPPORT_BAUDRATE)
            print("The target is waiting for the firmware file...")
            sys.stdout.flush()
            # ``with`` releases the lock even if the baud-rate switch raises.
            with MUTEX_SERIAL:
                seconds, upload_baudrate = try_to_reset_with_higher_baudrate(serial)
            print("This maybe take %s seconds with %s baud rate..." % (seconds, upload_baudrate))
            sys.stdout.flush()
            return True
        serial.set_port_baudrate(115200)
        misses += 1
        if misses > max_misses:
            print("!!! DO NOT GET C symbol from MCU at %r !!!"
                  % int(SUPPORT_BAUDRATE[BR_INDEX % len(SUPPORT_BAUDRATE)]))
            return False
    
def upgrade_firmware(serial):
    """Send the image at ``serial.image_path()`` to the MCU over XMODEM.

    Uses XMODEM-1K when ENABLE_XMODEM_1K is set and classic XMODEM
    otherwise. On success it waits 10 s, toggles rtscts to restart the
    board and closes the port. On transfer failure it exits with
    ``m_err.err_burn_firmware``.

    Raises:
        OSError: if the image file cannot be opened for reading. The port
            is closed before re-raising.
    """
    print("Start upgrade %s " % serial.image_path())
    sys.stdout.flush()
    try:
        # Only verify that the image is readable: send_file() below receives
        # the path, not a handle. Read-only mode suffices ('rb+' would reject
        # write-protected images), and ``with`` closes the handle again.
        with open(serial.image_path(), 'rb'):
            pass
    except OSError:
        print("Can't open %s file." % serial.image_path())
        sys.stdout.flush()
        serial.close()
        raise

    serial.set_timeout(1)
    serial.rst_flushinput()
    sys.stdout.flush()
    print("Please wait for upgrade ...")
    sys.stdout.flush()
    modem_class = XMODEM1k if ENABLE_XMODEM_1K else XMODEM
    modem = modem_class(serial.getc, serial.putc_bar, serial.reset_bar,
                        serial.update_img_to_FLS, serial.image_path)
    result = modem.send_file(serial.image_path())
    if not result:
        print("Upgrade %s image fail!" % serial.image_path())
        sys.stdout.flush()
        sys.exit(m_err['err_burn_firmware'])

    print("Upgrade %s image success!" % serial.image_path())
    sys.stdout.flush()
    print("Please wait for about 10 seconds before uncompress & restart...")
    wait_for_secs(10, print_flag=0)

    # Toggle rtscts to restart the board ("restart TB01").
    serial.rtscts(True)
    time.sleep(0.4)
    serial.rtscts(False)
    time.sleep(0.1)
    serial.close()
    
def main(argv):
    """Script entry point: parse options, open the port, reset the MCU, flash.

    *argv* is accepted for the ``__main__`` guard, but get_args() reads the
    options itself. If the MCU cannot be brought into download mode, the
    process exits with ``m_err.err_restart_mcu``.
    """
    get_args()
    download = WMDownload(port=COM_NAME, baud=BAUD_SPEED, image=IMG_FILE)

    sys.stdout.flush()
    banner = ("Serial open success! com: %s, baudrate: %s." % download.info(),
              'Waiting for restarting device ...')
    for line in banner:
        print(line)
        sys.stdout.flush()

    download.set_timeout(0.1)
    if not restart_w600(download):
        print("Restart MCU failure.")
        sys.stdout.flush()
        os._exit(m_err['err_restart_mcu'])
    upgrade_firmware(download)
    
if __name__ == '__main__':
    # Run as a script: hand the full command line to main().
    main(argv=sys.argv)
