当前位置:首页 > 科技数码 > 正文

python:matplotlib数据可视化(下)

摘要: python:matplotlib数据可视化(下)最佳答案53678位专家为你答疑解惑python:matplotlib数据可视...

python:matplotlib数据可视化(下)

最佳答案 53678位专家为你答疑解惑

python:matplotlib数据可视化(下)

python:matplotlib数据可视化(上)

第十八章 注解股票图表的最后价格

在这个 Matplotlib 教程中,我们将展示如何跟踪股票的最后价格的示例,通过将其注解到轴域的右侧,就像许多图表应用程序会做的那样。

虽然人们喜欢在他们的实时图表中看到历史价格,他们也想看到最新的价格。 大多数应用程序做的是,在价格的y轴高度处注释最后价格,然后突出显示它,并在价格变化时,在框中将其略微移动。 使用我们最近学习的注解教程,我们可以添加一个bbox

我们的核心代码是:

bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)ax1.annotate(str(closep[-1]), (date[-1], closep[-1]),             xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)

我们使用ax1.annotate来放置最后价格的字符串值。 我们不在这里使用它,但我们将要注解的点指定为图上最后一个点。 接下来,我们使用xytext将我们的文本放置到特定位置。 我们将它的y坐标指定为最后一个点的y坐标,x坐标指定为最后一个点的x坐标,再加上几个点。我们这样做是为了将它移出图表。 将文本放在图形外面就足够了,但现在它只是一些浮动文本。

我们使用bbox参数在文本周围创建一个框。 我们使用bbox_props创建一个属性字典,包含盒子样式,然后是白色(w)前景色,黑色(k)边框颜色并且线宽为 1。 更多框样式请参阅 matplotlib 注解文档。

最后,这个注解向右移动,需要我们使用subplots_adjust来创建一些新空间:

plt.subplots_adjust(left=0.11, bottom=0.24, right=0.87, top=0.90, wspace=0.2, hspace=0)

这里的完整代码如下:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)def bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((1,1), (0,0))    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')    for label in ax1.xaxis.get_ticklabels():        label.set_rotation(45)    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax1.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax1.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax1.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+3, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax1.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text##    ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict)    plt.xlabel('Date')    plt.ylabel('Price')    plt.title(stock)    #plt.legend()    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.87, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('EBAY')

结果为:

image第十九章 子图

在这个 Matplotlib 教程中,我们将讨论子图。 有两种处理子图的主要方法,用于在同一图上创建多个图表。 现在,我们将从一个干净的代码开始。 如果你一直关注这个教程,那么请确保保留旧的代码,或者你可以随时重新查看上一个教程的代码。

首先,让我们使用样式,创建我们的图表,然后创建一个随机创建示例绘图的函数:

import randomimport matplotlib.pyplot as pltfrom matplotlib import stylestyle.use('fivethirtyeight')fig=plt.figure()def create_plots():    xs=[]    ys=[]    for i in range(10):        x=i        y=random.randrange(10)        xs.append(x)        ys.append(y)    return xs, ys

现在,我们开始使用add_subplot方法创建子图:

ax1=fig.add_subplot(221)ax2=fig.add_subplot(222)ax3=fig.add_subplot(212)

它的工作原理是使用 3 个数字,即:行数(numRows)、列数(numCols)和绘图编号(plotNum)。

所以,221 表示两行两列的第一个位置。222 是两行两列的第二个位置。最后,212 是两行一列的第二个位置。

2x2:+-----+-----+|  1  |  2  |+-----+-----+|  3  |  4  |+-----+-----+2x1:+-----------+|     1     |+-----------+|     2     |+-----------+

译者注:原文此处表述有误,译文已更改。

译者注:221是缩写形式,仅在行数乘列数小于 10 时有效,否则要写成2,2,1

此代码结果为:

image

这就是add_subplot。 尝试一些你认为可能很有趣的配置,然后尝试使用add_subplot创建它们,直到你感到满意。

接下来,让我们介绍另一种方法,它是subplot2grid

删除或注释掉其他轴域定义,然后添加:

ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)

所以,add_subplot不能让我们使一个绘图覆盖多个位置。 但是这个新的subplot2grid可以。 所以,subplot2grid的工作方式是首先传递一个元组,它是网格形状。 我们传递了(6,1),这意味着整个图表分为六行一列。 下一个元组是左上角的起始点。 对于ax1,这是0,0,因此它起始于顶部。 接下来,我们可以选择指定rowspancolspan。 这是轴域所占的行数和列数。

6x1:          colspan=1(0,0)   +-----------+        |    ax1    | rowspan=1(1,0)   +-----------+        |           |        |    ax2    | rowspan=4        |           |        |           |(5,0)   +-----------+        |    ax3    | rowspan=1        +-----------+

结果为:

image

显然,我们在这里有一些重叠的问题,我们可以调整子图来处理它。

再次,尝试构思各种配置的子图,使用subplot2grid制作出来,直到你感到满意!

我们将继续使用subplot2grid,将它应用到我们已经逐步建立的代码中,我们将在下一个教程中继续。

第二十一章 更多指标数据

在这篇 Matplotlib 教程中,我们介绍了添加一些简单的函数来计算数据,以便我们填充我们的轴域。 一个是简单的移动均值,另一个是简单的价格 HML 计算。

这些新函数是:

def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lows

你不需要太过专注于理解移动均值的工作原理,我们只是对样本数据来计算它,以便可以学习更多自定义 Matplotlib 的东西。

我们还想在脚本顶部为移动均值定义一些值:

MA1=10MA2=30

下面,在我们的graph_data函数中:

ma1=moving_average(closep,MA1)ma2=moving_average(closep,MA2)start=len(date[MA2-1:])h_l=list(map(high_minus_low, highp, lowp))

在这里,我们计算两个移动均值和 HML。

我们还定义了一个『起始』点。 我们这样做是因为我们希望我们的数据排成一行。 例如,20 天的移动均值需要 20 个数据点。 这意味着我们不能在第 5 天真正计算 20 天的移动均值。 因此,当我们计算移动均值时,我们会失去一些数据。 为了处理这种数据的减法,我们使用起始变量来计算应该有多少数据。 这里,我们可以安全地使用[-start:]绘制移动均值,并且如果我们希望的话,对所有绘图进行上述步骤来排列数据。

接下来,我们可以在ax1上绘制 HML,通过这样:

ax1.plot_date(date,h_l,'-')

最后我们可以通过这样向ax3添加移动均值:

ax3.plot(date[-start:], ma1[-start:])ax3.plot(date[-start:], ma2[-start:])

我们的完整代码,包括增加我们所用的时间范围:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)    plt.xlabel('Date')    plt.ylabel('Price')    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date,h_l,'-')    candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')    for label in ax2.xaxis.get_ticklabels():        label.set_rotation(45)    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax2.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax3.plot(date[-start:], ma1[-start:])    ax3.plot(date[-start:], ma2[-start:])    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('EBAY')

代码效果如图:

image第二十二章 自定义填充、修剪和清除

欢迎阅读另一个 Matplotlib 教程! 在本教程中,我们将清除图表,然后再做一些自定义。

我们当前的代码是:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)    plt.xlabel('Date')    plt.ylabel('Price')    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date,h_l,'-')    candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')    for label in ax2.xaxis.get_ticklabels():        label.set_rotation(45)    ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax2.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax3.plot(date[-start:], ma1[-start:])    ax3.plot(date[-start:], ma2[-start:])    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('EBAY')

现在我认为向我们的移动均值添加自定义填充是一个很好的主意。 移动均值通常用于说明价格趋势。 这个想法是,你可以计算一个快速和一个慢速的移动均值。 一般来说,移动均值用于使价格变得『平滑』。 他们总是『滞后』于价格,但是我们的想法是计算不同的速度。 移动均值越大就越『慢』。 所以这个想法是,如果『较快』的移动均值超过『较慢』的均值,那么价格就会上升,这是一件好事。 如果较快的 MA 从较慢的 MA 下方穿过,则这是下降趋势并且通常被视为坏事。 我的想法是在快速和慢速 MA 之间填充,『上升』趋势为绿色,然后下降趋势为红色。 方法如下:

ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                 where=(ma1[-start:] < ma2[-start:]),                 facecolor='r', edgecolor='r', alpha=0.5)ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                 where=(ma1[-start:] > ma2[-start:]),                 facecolor='g', edgecolor='g', alpha=0.5)

下面,我们会碰到一些我们可解决的问题:

ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))for label in ax3.xaxis.get_ticklabels():    label.set_rotation(45)plt.setp(ax1.get_xticklabels(), visible=False)plt.setp(ax2.get_xticklabels(), visible=False)

这里,我们剪切和粘贴ax2日期格式,然后我们将x刻度标签设置为false,去掉它们!

我们还可以通过在轴域定义中执行以下操作,为每个轴域提供自定义标签:

fig=plt.figure()ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)plt.title(stock)ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)plt.xlabel('Date')plt.ylabel('Price')ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)

接下来,我们可以看到,我们y刻度有许多数字,经常互相覆盖。 我们也看到轴之间互相重叠。 我们可以这样:

ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='lower'))

所以,这里发生的是,我们通过首先将nbins设置为 5 来修改我们的y轴对象。这意味着我们显示的标签最多为 5 个。然后我们还可以『修剪』标签,因此,在我们这里, 我们修剪底部标签,这会使它消失,所以现在不会有任何文本重叠。 我们仍然可能打算修剪ax2的顶部标签,但这里是我们目前为止的源代码:

当前的源码:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    plt.ylabel('H-L')    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)    plt.ylabel('Price')    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)    plt.ylabel('MAvgs')    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date,h_l,'-')    ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='lower'))    candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f')    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax3.plot(date[-start:], ma1[-start:], linewidth=1)    ax3.plot(date[-start:], ma2[-start:], linewidth=1)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] < ma2[-start:]),                     facecolor='r', edgecolor='r', alpha=0.5)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] > ma2[-start:]),                     facecolor='g', edgecolor='g', alpha=0.5)    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))    for label in ax3.xaxis.get_ticklabels():        label.set_rotation(45)    plt.setp(ax1.get_xticklabels(), visible=False)    plt.setp(ax2.get_xticklabels(), visible=False)    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('EBAY')
image

看起来好了一些,但是仍然有一些东西需要清除。

第二十三章 共享 X 轴

在这个 Matplotlib 数据可视化教程中,我们将讨论sharex选项,它允许我们在图表之间共享x轴。将sharex看做『复制 x』也许更好。

在我们开始之前,首先我们要做些修剪并在另一个轴上设置最大刻度数,如下所示:

ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))

以及

ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))

现在,让我们共享所有轴域之间的x轴。 为此,我们需要将其添加到轴域定义中:

fig=plt.figure()ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)plt.title(stock)plt.ylabel('H-L')ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)plt.ylabel('Price')ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)plt.ylabel('MAvgs')

上面,对于ax2ax3,我们添加一个新的参数,称为sharex,然后我们说,我们要与ax1共享x轴。

使用这种方式,我们可以加载图表,然后我们可以放大到一个特定的点,结果将是这样:

image

所以这意味着所有轴域沿着它们的x轴一起移动。 这很酷吧!

接下来,让我们将[-start:]应用到所有数据,所以所有轴域都起始于相同地方。 我们最终的代码为:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    plt.ylabel('H-L')    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)    plt.ylabel('Price')    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)    plt.ylabel('MAvgs')    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date[-start:],h_l[-start:],'-')    ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))    candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')    ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax3.plot(date[-start:], ma1[-start:], linewidth=1)    ax3.plot(date[-start:], ma2[-start:], linewidth=1)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] < ma2[-start:]),                     facecolor='r', edgecolor='r', alpha=0.5)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] > ma2[-start:]),                     facecolor='g', edgecolor='g', alpha=0.5)    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))    for label in ax3.xaxis.get_ticklabels():        label.set_rotation(45)    plt.setp(ax1.get_xticklabels(), visible=False)    plt.setp(ax2.get_xticklabels(), visible=False)    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('EBAY')

下面我们会讨论如何创建多个y轴。

第二十四章 多个 Y 轴

在这篇 Matplotlib 教程中,我们将介绍如何在同一子图上使用多个 Y 轴。 在我们的例子中,我们有兴趣在同一个图表及同一个子图上绘制股票价格和交易量。

为此,首先我们需要定义一个新的轴域,但是这个轴域是ax2仅带有x轴的『双生子』。

这足以创建轴域了。我们叫它ax2v,因为这个轴域是ax2加交易量。

现在,我们在轴域上定义绘图,我们将添加:

ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)

我们在 0 和当前交易量之间填充,给予它蓝色的前景色,然后给予它一个透明度。 我们想要应用幽冥毒,以防交易量最终覆盖其它东西,所以我们仍然可以看到这两个元素。

所以,到现在为止,我们的代码为:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    plt.ylabel('H-L')    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)    plt.ylabel('Price')    ax2v=ax2.twinx()    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)    plt.ylabel('MAvgs')    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date[-start:],h_l[-start:],'-')    ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))    candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')    ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)    ax3.plot(date[-start:], ma1[-start:], linewidth=1)    ax3.plot(date[-start:], ma2[-start:], linewidth=1)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] < ma2[-start:]),                     facecolor='r', edgecolor='r', alpha=0.5)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] > ma2[-start:]),                     facecolor='g', edgecolor='g', alpha=0.5)    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))    for label in ax3.xaxis.get_ticklabels():        label.set_rotation(45)    plt.setp(ax1.get_xticklabels(), visible=False)    plt.setp(ax2.get_xticklabels(), visible=False)    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('GOOG')

会生成:

image

太棒了,到目前为止还不错。 接下来,我们可能要删除新y轴上的标签,然后我们也可能不想让交易量占用太多空间。 没问题:

首先:

ax2v.axes.yaxis.set_ticklabels([])

上面将y刻度标签设置为一个空列表,所以不会有任何标签了。

译者注:所以将标签删除之后,添加新轴的意义是什么?直接在原轴域上绘图就可以了。

接下来,我们可能要将网格设置为false,使轴域上不会有双网格:

ax2v.grid(False)

最后,为了处理交易量占用很多空间,我们可以做以下操作:

ax2v.set_ylim(0, 3*volume.max())

所以这设置y轴显示范围从 0 到交易量的最大值的 3 倍。 这意味着,在最高点,交易量最多可占据图形的33%。 所以,增加volume.max的倍数越多,空间就越小/越少。

现在,我们的图表为:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure()    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    plt.ylabel('H-L')    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)    plt.ylabel('Price')    ax2v=ax2.twinx()    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)    plt.ylabel('MAvgs')    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date[-start:],h_l[-start:],'-')    ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))    candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')    ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+5, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)    ax2v.axes.yaxis.set_ticklabels([])    ax2v.grid(False)    ax2v.set_ylim(0, 3*volume.max())    ax3.plot(date[-start:], ma1[-start:], linewidth=1)    ax3.plot(date[-start:], ma2[-start:], linewidth=1)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] < ma2[-start:]),                     facecolor='r', edgecolor='r', alpha=0.5)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] > ma2[-start:]),                     facecolor='g', edgecolor='g', alpha=0.5)    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))    for label in ax3.xaxis.get_ticklabels():        label.set_rotation(45)    plt.setp(ax1.get_xticklabels(), visible=False)    plt.setp(ax2.get_xticklabels(), visible=False)    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    plt.show()graph_data('GOOG')

到这里,我们差不多完成了。 这里唯一的缺陷是一个好的图例。 一些线条是显而易见的,但人们可能会好奇移动均值的参数是什么,我们这里是 10 和 30。 添加自定义图例是下一个教程中涉及的内容。

第二十五章 自定义图例

在这篇 Matplotlib 教程中,我们将讨论自定义图例。 我们已经介绍了添加图例的基础知识。

图例的主要问题通常是图例阻碍了数据的展示。 这里有几个选项。 一个选项是将图例放在轴域外,但是我们在这里有多个子图,这是非常困难的。 相反,我们将使图例稍微小一点,然后应用一个透明度。

首先,为了创建一个图例,我们需要向我们的数据添加我们想要显示在图例上的标签。

ax1.plot_date(date[-start:],h_l[-start:],'-', label='H-L')...ax2v.plot([],[], color='#0079a3', alpha=0.4, label='Volume')...ax3.plot(date[-start:], ma1[-start:], linewidth=1, label=(str(MA1)+'MA'))ax3.plot(date[-start:], ma2[-start:], linewidth=1, label=(str(MA2)+'MA'))

请注意,我们通过创建空行为交易量添加了标签。 请记住,我们不能对任何填充应用标签,所以这就是我们添加这个空行的原因。

现在,我们可以在右下角添加图例,通过在plt.show()之前执行以下操作:

ax1.legend()ax2v.legend()ax3.legend()

会生成:

image

所以,我们可以看到,图例还是占用了一些位置。 让我们更改位置,大小并添加透明度:

ax1.legend()leg=ax1.legend(loc=9, ncol=2,prop={'size':11})leg.get_frame().set_alpha(0.4)ax2v.legend()leg=ax2v.legend(loc=9, ncol=2,prop={'size':11})leg.get_frame().set_alpha(0.4)ax3.legend()leg=ax3.legend(loc=9, ncol=2,prop={'size':11})leg.get_frame().set_alpha(0.4)

所有的图例位于位置 9(上中间)。 有很多地方可放置图例,我们可以为参数传入不同的位置号码,来看看它们都位于哪里。 ncol参数允许我们指定图例中的列数。 这里只有一列,如果图例中有 2 个项目,他们将堆叠在一列中。 最后,我们将尺寸规定为更小。 之后,我们对整个图例应用0.4的透明度。

现在我们的结果为:

image

完整的代码为:

import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window):    weights=np.repeat(1.0, window)/window    smas=np.convolve(values, weights, 'valid')    return smasdef high_minus_low(highs, lows):    return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'):    strconverter=mdates.strpdate2num(fmt)    def bytesconverter(b):        s=b.decode(encoding)        return strconverter(s)    return bytesconverterdef graph_data(stock):    fig=plt.figure(facecolor='#f0f0f0')    ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)    plt.title(stock)    plt.ylabel('H-L')    ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)    plt.ylabel('Price')    ax2v=ax2.twinx()    ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)    plt.ylabel('MAvgs')    stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv'    source_code=urllib.request.urlopen(stock_price_url).read().decode()    stock_data=[]    split_source=source_code.split('\n')    for line in split_source:        split_line=line.split(',')        if len(split_line)==6:            if 'values' not in line and 'labels' not in line:                stock_data.append(line)    date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data,                                                          delimiter=',',                                                          unpack=True,                                                          converters={0: bytespdate2num('%Y%m%d')})    x=0    y=len(date)    ohlc=[]    while x < y:        append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x]        ohlc.append(append_me)        x+=1    ma1=moving_average(closep,MA1)    ma2=moving_average(closep,MA2)    start=len(date[MA2-1:])    h_l=list(map(high_minus_low, highp, lowp))    ax1.plot_date(date[-start:],h_l[-start:],'-', label='H-L')    ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower'))    candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f')    ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))    ax2.grid(True)    bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)    ax2.annotate(str(closep[-1]), (date[-1], closep[-1]),                 xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)##    # Annotation example with arrow##    ax2.annotate('Bad News!',(date[11],highp[11]),##                 xytext=(0.8, 0.9), textcoords='axes fraction',##                 arrowprops=dict(facecolor='grey',color='grey'))####    ##    # Font dict example##    font_dict={'family':'serif',##                 'color':'darkred',##                 'size':15}##    # Hard coded text ##    ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict)    ax2v.plot([],[], color='#0079a3', alpha=0.4, label='Volume')    ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)    ax2v.axes.yaxis.set_ticklabels([])    ax2v.grid(False)    ax2v.set_ylim(0, 3*volume.max())    ax3.plot(date[-start:], ma1[-start:], linewidth=1, label=(str(MA1)+'MA'))    ax3.plot(date[-start:], ma2[-start:], linewidth=1, label=(str(MA2)+'MA'))    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] < ma2[-start:]),                     facecolor='r', edgecolor='r', alpha=0.5)    ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:],                     where=(ma1[-start:] > ma2[-start:]),                     facecolor='g', edgecolor='g', alpha=0.5)    ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))    ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))    ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))    for label in ax3.xaxis.get_ticklabels():        label.set_rotation(45)    plt.setp(ax1.get_xticklabels(), visible=False)    plt.setp(ax2.get_xticklabels(), visible=False)    plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0)    ax1.legend()    leg=ax1.legend(loc=9, ncol=2,prop={'size':11})    leg.get_frame().set_alpha(0.4)    ax2v.legend()    leg=ax2v.legend(loc=9, ncol=2,prop={'size':11})    leg.get_frame().set_alpha(0.4)    ax3.legend()    leg=ax3.legend(loc=9, ncol=2,prop={'size':11})    leg.get_frame().set_alpha(0.4)    plt.show()    fig.savefig('google.png', facecolor=fig.get_facecolor())graph_data('GOOG')

现在我们可以看到图例,但也看到了图例下的任何信息。 还要注意额外函数fig.savefig。 这是自动保存图形的图像的方式。 我们还可以设置所保存的图形的前景色,使背景不是白色的,如我们的例子所示。

这就是目前为止,我想要显示的典型 Matplotlib 图表。 接下来,我们将涉及Basemap,它是一个 Matplotlib 扩展,用于绘制地理位置,然后我打算讲解 Matplotlib 中的 3D 图形。

第二十六章 Basemap 地理绘图

在这个 Matplotlib 教程中,我们将涉及地理绘图模块BasemapBasemap是 Matplotlib 的扩展。

为了使用Basemap,我们首先需要安装它。 为了获得Basemap,你可以从这里获取:http://matplotlib.org/basemap/users/download.html,或者你可以访问http://www.lfd.uci.edu/~gohlke/pythonlibs/。

如果你在安装Basemap时遇到问题,请查看pip安装教程。

一旦你安装了Basemap,你就可以创建地图了。 首先,让我们投影一个简单的地图。 为此,我们需要导入Basemappyplot,创建投影,至少绘制某种轮廓或数据,然后我们可以显示图形。

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill')m.drawcoastlines()plt.show()

上面的代码结果如下:

image

这是使用 Miller 投影完成的,这只是许多Basemap投影选项之一。

第二十七章 Basemap 自定义

在这篇 Matplotlib 教程中,我们继续使用Basemap地理绘图扩展。 我们将展示一些我们可用的自定义选项。

首先,从上一个教程中获取我们的起始代码:

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill')m.drawcoastlines()plt.show()

我们可以从放大到特定区域来开始:

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill',            llcrnrlat=-40,            llcrnrlon=-40,            urcrnrlat=50,            urcrnrlon=75)m.drawcoastlines()plt.show()

这里的参数是:

llcrnrlat - 左下角的纬度llcrnrlon - 左下角的经度urcrnrlat - 右上角的纬度urcrnrlon - 右上角的经度

此外,坐标需要转换,其中西经和南纬坐标是负值,北纬和东经坐标是正值。

使用这些坐标,Basemap会选择它们之间的区域。

image

下面,我们要使用一些东西,类似:

m.drawcountries(linewidth=2)

这会画出国家,并使用线宽为 2 的线条生成分界线。

另一个选项是:

m.drawstates(color='b')

这会用蓝色线条画出州。

你也可以执行:

m.drawcounties(color='darkred')

这会画出国家。

所以,我们的代码是:

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill',            llcrnrlat=-90,            llcrnrlon=-180,            urcrnrlat=90,            urcrnrlon=180)m.drawcoastlines()m.drawcountries(linewidth=2)m.drawstates(color='b')m.drawcounties(color='darkred')plt.title('Basemap Tutorial')plt.show()
image

很难说,但我们定义了美国的区县的线条。 我们可以使用放大镜放大Basemap图形,就像其他图形那样,会生成:

image

另一个有用的选项是Basemap调用中的『分辨率』选项。

m=Basemap(projection='mill',            llcrnrlat=-90,            llcrnrlon=-180,            urcrnrlat=90,            urcrnrlon=180,            resolution='l')

分辨率的选项为:

c - 粗糙l - 低h - 高f - 完整

对于更高的分辨率,你应该放大到很大,否则这可能只是浪费。

另一个选项是使用etopo()绘制地形,如:

m.etopo()

使用drawcountries方法绘制此图形会生成:

[图片上传失败...(image-eaca82-1558445064342)]

最后,有一个蓝色的大理石版本,你可以调用:

m.bluemarble()

会生成:

[图片上传失败...(image-3f9978-1558445064342)]

目前为止的代码:

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill',            llcrnrlat=-90,            llcrnrlon=-180,            urcrnrlat=90,            urcrnrlon=180,            resolution='l')m.drawcoastlines()m.drawcountries(linewidth=2)##m.drawstates(color='b')##m.drawcounties(color='darkred')#m.fillcontinents()#m.etopo()m.bluemarble()plt.title('Basemap Tutorial')plt.show()
第二十八章 在 Basemap 中绘制坐标

欢迎阅读另一个 Matplotlib Basemap 教程。 在本教程中,我们将介绍如何绘制单个坐标,以及如何在地理区域中连接这些坐标。

首先,我们将从一些基本的起始数据开始:

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill',            llcrnrlat=25,            llcrnrlon=-130,            urcrnrlat=50,            urcrnrlon=-60,            resolution='l')m.drawcoastlines()m.drawcountries(linewidth=2)m.drawstates(color='b')

接下来,我们可以绘制坐标,从获得它们的实际坐标开始。 记住,南纬和西经坐标需要转换为负值。 例如,纽约市是北纬40.7127西经74.0059。 我们可以在我们的程序中定义这些坐标,如:

NYClat, NYClon=40.7127, -74.0059

之后我们将这些转换为要绘制的xy坐标。

xpt, ypt=m(NYClon, NYClat)

注意这里,我们现在已经将坐标顺序翻转为lon, lat(纬度,经度)。 坐标通常以lat, lon顺序给出。 然而,在图形中,lat, long转换为y, x,我们显然不需要。 在某些时候,你必须翻转它们。 不要忘记这部分!

最后,我们可以绘制如下的坐标:

m.plot(xpt, ypt, 'c*', markersize=15)

这个图表上有一个青色的星,大小为 15。更多标记类型请参阅:Matplotlib 标记文档。

接下来,让我们再画一个位置,洛杉矶,加利福尼亚:

LAlat, LAlon=34.05, -118.25xpt, ypt=m(LAlon, LAlat)m.plot(xpt, ypt, 'g^', markersize=15)

这次我们画出一个绿色三角,执行代码会生成:

[图片上传失败...(image-8e5eff-1558445064342)]

如果我们想连接这些图块怎么办?原来,我们可以像其它 Matplotlib 图表那样实现它。

首先,我们将那些xptypt坐标保存到列表,类似这样的东西:

xs=[]ys=[]NYClat, NYClon=40.7127, -74.0059xpt, ypt=m(NYClon, NYClat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'c*', markersize=15)LAlat, LAlon=34.05, -118.25xpt, ypt=m(LAlon, LAlat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'g^', markersize=15)m.plot(xs, ys, color='r', linewidth=3, label='Flight 98')

会生成:

image

太棒了。有时我们需要以圆弧连接图上的两个坐标。如何实现呢?

m.drawgreatcircle(NYClon, NYClat, LAlon, LAlat, color='c', linewidth=3, label='Arc')

我们的完整代码为:

from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill',            llcrnrlat=25,            llcrnrlon=-130,            urcrnrlat=50,            urcrnrlon=-60,            resolution='l')m.drawcoastlines()m.drawcountries(linewidth=2)m.drawstates(color='b')#m.drawcounties(color='darkred')#m.fillcontinents()#m.etopo()#m.bluemarble()xs=[]ys=[]NYClat, NYClon=40.7127, -74.0059xpt, ypt=m(NYClon, NYClat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'c*', markersize=15)LAlat, LAlon=34.05, -118.25xpt, ypt=m(LAlon, LAlat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'g^', markersize=15)m.plot(xs, ys, color='r', linewidth=3, label='Flight 98')m.drawgreatcircle(NYClon, NYClat, LAlon, LAlat, color='c', linewidth=3, label='Arc')plt.legend(loc=4)plt.title('Basemap Tutorial')plt.show()

结果为:

image

这就是Basemap的全部了,下一章关于 Matplotlib 的 3D 绘图。

第二十九章 3D 绘图

您好,欢迎阅读 Matplotlib 教程中的 3D 绘图。 Matplotlib 已经内置了三维图形,所以我们不需要再下载任何东西。 首先,我们需要引入一些完整的模块:

from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as plt

使用axes3d是因为它需要不同种类的轴域,以便在三维中实际绘制一些东西。 下面:

fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')

在这里,我们像通常一样定义图形,然后我们将ax1定义为通常的子图,只是这次使用 3D 投影。 我们需要这样做,以便提醒 Matplotlib 我们要提供三维数据。

现在让我们创建一些 3D 数据:

x=[1,2,3,4,5,6,7,8,9,10]y=[5,6,7,8,2,5,6,3,7,2]z=[1,2,6,3,2,7,3,3,7,2]

接下来,我们绘制它。 首先,让我们展示一个简单的线框示例:

ax1.plot_wireframe(x,y,z)

最后:

ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()

我们完整的代码是:

from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltfrom matplotlib import stylestyle.use('fivethirtyeight')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x=[1,2,3,4,5,6,7,8,9,10]y=[5,6,7,8,2,5,6,3,7,2]z=[1,2,6,3,2,7,3,3,7,2]ax1.plot_wireframe(x,y,z)ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()

结果为(包括所用的样式):

image

这些 3D 图形可以进行交互。 首先,您可以使用鼠标左键单击并拖动来移动图形。 您还可以使用鼠标右键单击并拖动来放大或缩小。

第三十章 3D 散点图

欢迎阅读另一个 3D Matplotlib 教程,会涉及如何绘制三维散点图。

绘制 3D 散点图非常类似于通常的散点图以及 3D 线框图。

一个简单示例:

from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltfrom matplotlib import stylestyle.use('ggplot')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x=[1,2,3,4,5,6,7,8,9,10]y=[5,6,7,8,2,5,6,3,7,2]z=[1,2,6,3,2,7,3,3,7,2]x2=[-1,-2,-3,-4,-5,-6,-7,-8,-9,-10]y2=[-5,-6,-7,-8,-2,-5,-6,-3,-7,-2]z2=[1,2,6,3,2,7,3,3,7,2]ax1.scatter(x, y, z, c='g', marker='o')ax1.scatter(x2, y2, z2, c='r', marker='o')ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()

结果为:

image

要记住你可以修改这些绘图的大小和标记,就像通常的散点图那样。

第三十一章 3D 条形图

在这个 Matplotlib 教程中,我们要介绍 3D 条形图。 3D 条形图是非常独特的,因为它允许我们绘制多于 3 个维度。 不,你不能超过第三个维度来绘制,但你可以绘制多于 3 个维度。

对于条形图,你需要拥有条形的起点,条形的高度和宽度。 但对于 3D 条形图,你还有另一个选项,就是条形的深度。 大多数情况下,条形图从轴上的条形平面开始,但是你也可以通过打破此约束来添加另一个维度。 然而,我们会让它非常简单:

from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltimport numpy as npfrom matplotlib import stylestyle.use('ggplot')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x3=[1,2,3,4,5,6,7,8,9,10]y3=[5,6,7,8,2,5,6,3,7,2]z3=np.zeros(10)dx=np.ones(10)dy=np.ones(10)dz=[1,2,3,4,5,6,7,8,9,10]ax1.bar3d(x3, y3, z3, dx, dy, dz)ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()

注意这里,我们必须定义xyz,然后是 3 个维度的宽度、高度和深度。 这会生成:

image第三十二章 总结

欢迎阅读最后的 Matplotlib 教程。 在这里我们将整理整个系列,并显示一个稍微更复杂的 3D 线框图:

from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltimport numpy as npfrom matplotlib import stylestyle.use('ggplot')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x, y, z=axes3d.get_test_data()print(axes3d.__file__)ax1.plot_wireframe(x,y,z, rstride=3, cstride=3)ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()
image

如果你从一开始就关注这个教程的话,那么你已经学会了 Matplotlib 提供的大部分内容。 你可能不相信,但Matplotlib 仍然可以做很多其他的事情! 请继续学习,你可以随时访问 Matplotlib.org,并查看示例和图库页面。

如果你发现自己大量使用 Matplotlib,请考虑捐助给 John Hunter Memorial 基金。

注:空间曲面的画法

# 二次抛物面 z=x^2 + y^2x=np.linspace(-10, 10, 101)y=xx, y=np.meshgrid(x, y)z=x ** 2 + y ** 2ax=plot.subplot(111, projection='3d')ax.plot_wireframe(x, y, z)plot.show()
image
# 半径为 1 的球t=np.linspace(0, np.pi * 2, 100)s=np.linspace(0, np.pi, 100)t, s=np.meshgrid(t, s)x=np.cos(t) * np.sin(s)y=np.sin(t) * np.sin(s)z=np.cos(s)ax=plot.subplot(111, projection='3d')ax.plot_wireframe(x, y, z)plot.show()
image

参考链接

Python入门实战:一文看懂用Matplotlib实现数据可视化

导读:获取数据之后,而不知道如何查看数据,用途还是有限的。幸好,我们有Matplotlib!

Matplotlib 是基于 NumPy 数组构建的多平台数据可视化库。它是John Hunter 在2002年构想的,原本的设计是给 IPython 打补丁,让命令行中也可以有交互式的 MATLAB 风格的画图工具。

在近些年,已经出现了更新更好的工具最终替代了 Matplotlib(比如 R 语言中的ggplot和ggvis), 但 Matplotlib 依旧是一个经过良好测试的、跨平台的图形引擎。

作者:迈克尔·贝耶勒(Michael Beyeler)

如需转载请联系华章科技

01 引入 Matplotlib

如果已安装Anaconda Python版本,就已经安装好了可以使用的 Matplotlib。否则,可能要访问官网并从中获取安装说明:

http://matplotlib.org

正如使用np作为 NumPy 的缩写,我们将使用一些标准的缩写来表示 Matplotlib 的引入:

In [1]: import matplotlib as mplIn [2]: import matplotlib.pyplot as plt

在本书中,plt接口会被频繁使用。

02 生成一个简单的绘图

让我们创建第一个绘图。

假设想要画出正弦函数sin(x)的线性图。得到函数在x坐标轴上0≤x<10内所有点的值。我们将使用 NumPy 中的 linspace 函数来在x坐标轴上创建一个从0到10的线性空间,以及100个采样点:

In [3]: import numpy as npIn [4]: x=np.linspace(0, 10, 100)

可以使用 NumPy 中的sin函数得到所有x点的值,并通过调用plt中的plot函数把结果画出来:

In [5]: plt.plot(x, np.sin(x))

你亲自尝试了吗?发生了什么吗?有没有什么东西出现?

实际情况是,取决于你在哪里运行脚本,可能无法看到任何东西。有下面几种可能性:

1. 从.py脚本中绘图

如果从一个脚本中运行 Matplotlib,需要加上下面的这行调用:

plt.show()

在脚本末尾调用这个函数,你的绘图就会出现!

2. 从 IPython shell 中绘图

这实际上是交互式地执行Matplotlib最方便的方式。为了让绘图出现,需要在启动 IPython 后使用所谓的%matplotlib魔法命令。

In [1]: %matplotlibUsing matplotlib backend: TkAggIn [2]: import matplotlib.pyplot as plt

接下来,无须每次调用plt.show()函数,所有的绘图将会自动出现。

3. 从 Jupyter Notebook 中绘图

如果你是从基于浏览器的 Jupyter Notebook 中看这段代码,需要使用同样的%matplotlib魔法命令。然而,也可以直接在notebook中嵌入图形,这会有两种输出选项:

%matplotlib notebook 将会把交互式的绘图嵌入到notebook中%matplotlib inline 将会把绘图的静态图像嵌入到notebook中

在本书中,将会使用inline选项:

In [6]: %matplotlib inline

现在再次尝试一下:

In [7]: plt.plot(x, np.sin(x))Out[7]: [<matplotlib.lines.Line2D at 0x7f3aac426eb8>]

上面的命令会得到下面的绘图输出结果:

▲使用 Matplotlib 绘制正弦函数图像

如果想要把绘图保存下来留作以后使用,可以直接在 IPython 或者 Jupyter Notebook 使用下面的命令保存:

In [8]: plt.savefig('figures/02.03-sine.png')

仅需要确保你使用了支持的文件后缀,比如.jpg、.png、.tif、.svg、.eps或者.pdf。

Tips:可以在引入Matplotlib后通过运行plt.style.use(style_name)来修改绘图的风格。所有可用的风格在plt.style.available中列出。比如,尝试使用plt.style.use('fivethirtyeight')、plt.style.use('ggplot')或者plt.style.use('seaborn-dark')。为了更好玩,可以运行plt.xkcd(),然后尝试绘制一些别的图形。

03 可视化外部数据集的数据

作为本章最后一个测试,让我们对外部数据集进行可视化,比如scikit-learn中的数字数据集。

为此,需要三个可视化工具:

scikit-learn用于获取实际数据NumPy 用于数据再加工Matplotlib

那么开始引入这些包吧:

In [1]: import numpy as np... from sklearn import datasets... import matplotlib.pyplot as plt... % matplotlib inline

第一步是载入实际数据:

In [2]: digits=datasets.load_digits()

如果没记错的话,digits应该有两个不同的数据域:data域包含了真正的图像数据,target域包含了图像的标签。相对于相信我们的记忆,我们还是应该对digits稍加探索。输入它的名字,添加一个点号,然后按Tab键:digits.<TAB>,这个操作将向我们展示digits也包含了一些其他的域,比如一个名为images的域。images和data这两个域,似乎简单从形状上就可以区分。

In [3]: print(digits.data.shape)... print(digits.images.shape)Out[3]: (1797, 64)... (1797, 8, 8)

两种情况中,第一维对应的都是数据集中的图像数量。然而,data中所有像素都在一个大的向量中排列,而images保留了各个图像8×8的空间排列。

因此,如果想要绘制出一副单独的图像,使用images将更加合适。首先,使用NumPy的数组切片从数据集中获取一幅图像:

In [4]: img=digits.images[0, :, :]

这里是从1797个元素的数组中获取了它的第一行数据,这行数据对应的是8×8=64个像素。下面就可以使用plt中的imshow函数来绘制这幅图像:

In [5]: plt.imshow(img, cmap='gray')Out[5]: <matplotlib.image.AxesImage at 0x7efcd27f30f0>

上面的命令得到下面的输出:

▲数字数据集中的一个图像样例

此外,这里也使用cmap参数指定了一个颜色映射。默认情况下,Matplotlib 使用MATLAB默认的颜色映射jet。然而,在灰度图像的情况下,gray颜色映射更有效。

最后,可以使用plt的subplot函数绘制全部数字的样例。subplot函数与MATLAB中的函数一样,需要指定行数、列数以及当前的子绘图索引(从1开始计算)。我们将使用for 循环在数据集中迭代出前十张图像,每张图像都分配到一个单独的子绘图中。

In [6]: for image_index in range(10):... # 图像按0开始索引,子绘图按1开始索引... subplot_index=image_index + 1... plt.subplot(2, 5, subplot_index)... plt.imshow(digits.images[image_index, :, :], cmap='gray')

这会得到下面的输出结果:

▲数字数据集中的十幅样例图像

Tips:另一个拥有各种数据集资源的是我的母校——加利福尼亚大学欧文分校的机器学习仓库:

http://archive.ics.uci.edu/ml

关于作者:Michael Beyeler,华盛顿大学神经工程和数据科学专业的博士后,主攻仿生视觉计算模型,用以为盲人植入人工视网膜(仿生眼睛),改善盲人的视觉体验。 他的工作属于神经科学、计算机工程、计算机视觉和机器学习的交叉领域。同时他也是多个开源项目的积极贡献者。

本文摘编自《机器学习:使用OpenCV和Python进行智能图像处理》,经出版方授权发布。

延伸阅读《机器学习》

推荐语:OpenCV是一个综合了经典和先进计算机视觉、机器学习算法的开源库。通过与Python Anaconda版本结合,你就可以获取你所需要的所有开源计算库。 本书首先介绍分类和回归等统计学习的基本概念,然后详细讲解决策树、支持向量机和贝叶斯网络等算法,以及如何把它们与其他OpenCV函数结合。

发表评论