python:matplotlib数据可视化(下)
python:matplotlib数据可视化(上)
第十八章 注解股票图表的最后价格在这个 Matplotlib 教程中,我们将展示如何跟踪股票的最后价格的示例,通过将其注解到轴域的右侧,就像许多图表应用程序会做的那样。
虽然人们喜欢在他们的实时图表中看到历史价格,他们也想看到最新的价格。 大多数应用程序做的是,在价格的y
轴高度处注释最后价格,然后突出显示它,并在价格变化时,在框中将其略微移动。 使用我们最近学习的注解教程,我们可以添加一个bbox
。
我们的核心代码是:
bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1)ax1.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)
我们使用ax1.annotate
来放置最后价格的字符串值。 我们不在这里使用它,但我们将要注解的点指定为图上最后一个点。 接下来,我们使用xytext
将我们的文本放置到特定位置。 我们将它的y
坐标指定为最后一个点的y
坐标,x
坐标指定为最后一个点的x
坐标,再加上几个点。我们这样做是为了将它移出图表。 将文本放在图形外面就足够了,但现在它只是一些浮动文本。
我们使用bbox
参数在文本周围创建一个框。 我们使用bbox_props
创建一个属性字典,包含盒子样式,然后是白色(w
)前景色,黑色(k
)边框颜色并且线宽为 1。 更多框样式请参阅 matplotlib 注解文档。
最后,这个注解向右移动,需要我们使用subplots_adjust
来创建一些新空间:
plt.subplots_adjust(left=0.11, bottom=0.24, right=0.87, top=0.90, wspace=0.2, hspace=0)
这里的完整代码如下:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)def bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((1,1), (0,0)) stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1m/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 candlestick_ohlc(ax1, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f') for label in ax1.xaxis.get_ticklabels(): label.set_rotation(45) ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax1.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax1.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax1.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+3, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax1.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text## ax1.text(date[10], closep[1],'Text Example', fontdict=font_dict) plt.xlabel('Date') plt.ylabel('Price') plt.title(stock) #plt.legend() plt.subplots_adjust(left=0.11, bottom=0.24, right=0.87, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('EBAY')
结果为:
image第十九章 子图在这个 Matplotlib 教程中,我们将讨论子图。 有两种处理子图的主要方法,用于在同一图上创建多个图表。 现在,我们将从一个干净的代码开始。 如果你一直关注这个教程,那么请确保保留旧的代码,或者你可以随时重新查看上一个教程的代码。
首先,让我们使用样式,创建我们的图表,然后创建一个随机创建示例绘图的函数:
import randomimport matplotlib.pyplot as pltfrom matplotlib import stylestyle.use('fivethirtyeight')fig=plt.figure()def create_plots(): xs=[] ys=[] for i in range(10): x=i y=random.randrange(10) xs.append(x) ys.append(y) return xs, ys
现在,我们开始使用add_subplot
方法创建子图:
ax1=fig.add_subplot(221)ax2=fig.add_subplot(222)ax3=fig.add_subplot(212)
它的工作原理是使用 3 个数字,即:行数(numRows
)、列数(numCols
)和绘图编号(plotNum
)。
所以,221 表示两行两列的第一个位置。222 是两行两列的第二个位置。最后,212 是两行一列的第二个位置。
2x2:+-----+-----+| 1 | 2 |+-----+-----+| 3 | 4 |+-----+-----+2x1:+-----------+| 1 |+-----------+| 2 |+-----------+
译者注:原文此处表述有误,译文已更改。
译者注:
221
是缩写形式,仅在行数乘列数小于 10 时有效,否则要写成2,2,1
。
此代码结果为:
image这就是add_subplot
。 尝试一些你认为可能很有趣的配置,然后尝试使用add_subplot
创建它们,直到你感到满意。
接下来,让我们介绍另一种方法,它是subplot2grid
。
删除或注释掉其他轴域定义,然后添加:
ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
所以,add_subplot
不能让我们使一个绘图覆盖多个位置。 但是这个新的subplot2grid
可以。 所以,subplot2grid
的工作方式是首先传递一个元组,它是网格形状。 我们传递了(6,1)
,这意味着整个图表分为六行一列。 下一个元组是左上角的起始点。 对于ax1
,这是0,0
,因此它起始于顶部。 接下来,我们可以选择指定rowspan
和colspan
。 这是轴域所占的行数和列数。
6x1: colspan=1(0,0) +-----------+ | ax1 | rowspan=1(1,0) +-----------+ | | | ax2 | rowspan=4 | | | |(5,0) +-----------+ | ax3 | rowspan=1 +-----------+
结果为:
image显然,我们在这里有一些重叠的问题,我们可以调整子图来处理它。
再次,尝试构思各种配置的子图,使用subplot2grid
制作出来,直到你感到满意!
我们将继续使用subplot2grid
,将它应用到我们已经逐步建立的代码中,我们将在下一个教程中继续。
在这篇 Matplotlib 教程中,我们介绍了添加一些简单的函数来计算数据,以便我们填充我们的轴域。 一个是简单的移动均值,另一个是简单的价格 HML 计算。
这些新函数是:
def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lows
你不需要太过专注于理解移动均值的工作原理,我们只是对样本数据来计算它,以便可以学习更多自定义 Matplotlib 的东西。
我们还想在脚本顶部为移动均值定义一些值:
MA1=10MA2=30
下面,在我们的graph_data
函数中:
ma1=moving_average(closep,MA1)ma2=moving_average(closep,MA2)start=len(date[MA2-1:])h_l=list(map(high_minus_low, highp, lowp))
在这里,我们计算两个移动均值和 HML。
我们还定义了一个『起始』点。 我们这样做是因为我们希望我们的数据排成一行。 例如,20 天的移动均值需要 20 个数据点。 这意味着我们不能在第 5 天真正计算 20 天的移动均值。 因此,当我们计算移动均值时,我们会失去一些数据。 为了处理这种数据的减法,我们使用起始变量来计算应该有多少数据。 这里,我们可以安全地使用[-start:]
绘制移动均值,并且如果我们希望的话,对所有绘图进行上述步骤来排列数据。
接下来,我们可以在ax1
上绘制 HML,通过这样:
ax1.plot_date(date,h_l,'-')
最后我们可以通过这样向ax3
添加移动均值:
ax3.plot(date[-start:], ma1[-start:])ax3.plot(date[-start:], ma2[-start:])
我们的完整代码,包括增加我们所用的时间范围:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1) plt.xlabel('Date') plt.ylabel('Price') ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1) stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date,h_l,'-') candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f') for label in ax2.xaxis.get_ticklabels(): label.set_rotation(45) ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax2.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax3.plot(date[-start:], ma1[-start:]) ax3.plot(date[-start:], ma2[-start:]) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('EBAY')
代码效果如图:
image第二十二章 自定义填充、修剪和清除欢迎阅读另一个 Matplotlib 教程! 在本教程中,我们将清除图表,然后再做一些自定义。
我们当前的代码是:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1) plt.xlabel('Date') plt.ylabel('Price') ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1) stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date,h_l,'-') candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f') for label in ax2.xaxis.get_ticklabels(): label.set_rotation(45) ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax2.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax3.plot(date[-start:], ma1[-start:]) ax3.plot(date[-start:], ma2[-start:]) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('EBAY')
现在我认为向我们的移动均值添加自定义填充是一个很好的主意。 移动均值通常用于说明价格趋势。 这个想法是,你可以计算一个快速和一个慢速的移动均值。 一般来说,移动均值用于使价格变得『平滑』。 他们总是『滞后』于价格,但是我们的想法是计算不同的速度。 移动均值越大就越『慢』。 所以这个想法是,如果『较快』的移动均值超过『较慢』的均值,那么价格就会上升,这是一件好事。 如果较快的 MA 从较慢的 MA 下方穿过,则这是下降趋势并且通常被视为坏事。 我的想法是在快速和慢速 MA 之间填充,『上升』趋势为绿色,然后下降趋势为红色。 方法如下:
ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] < ma2[-start:]), facecolor='r', edgecolor='r', alpha=0.5)ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] > ma2[-start:]), facecolor='g', edgecolor='g', alpha=0.5)
下面,我们会碰到一些我们可解决的问题:
ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))ax3.xaxis.set_major_locator(mticker.MaxNLocator(10))for label in ax3.xaxis.get_ticklabels(): label.set_rotation(45)plt.setp(ax1.get_xticklabels(), visible=False)plt.setp(ax2.get_xticklabels(), visible=False)
这里,我们剪切和粘贴ax2
日期格式,然后我们将x
刻度标签设置为false
,去掉它们!
我们还可以通过在轴域定义中执行以下操作,为每个轴域提供自定义标签:
fig=plt.figure()ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)plt.title(stock)ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1)plt.xlabel('Date')plt.ylabel('Price')ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1)
接下来,我们可以看到,我们y
刻度有许多数字,经常互相覆盖。 我们也看到轴之间互相重叠。 我们可以这样:
ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='lower'))
所以,这里发生的是,我们通过首先将nbins
设置为 5 来修改我们的y
轴对象。这意味着我们显示的标签最多为 5 个。然后我们还可以『修剪』标签,因此,在我们这里, 我们修剪底部标签,这会使它消失,所以现在不会有任何文本重叠。 我们仍然可能打算修剪ax2
的顶部标签,但这里是我们目前为止的源代码:
当前的源码:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) plt.ylabel('H-L') ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1) plt.ylabel('Price') ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1) plt.ylabel('MAvgs') stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date,h_l,'-') ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=5, prune='lower')) candlestick_ohlc(ax2, ohlc, width=0.4, colorup='#77d879', colordown='#db3f3f') ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax3.plot(date[-start:], ma1[-start:], linewidth=1) ax3.plot(date[-start:], ma2[-start:], linewidth=1) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] < ma2[-start:]), facecolor='r', edgecolor='r', alpha=0.5) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] > ma2[-start:]), facecolor='g', edgecolor='g', alpha=0.5) ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax3.xaxis.set_major_locator(mticker.MaxNLocator(10)) for label in ax3.xaxis.get_ticklabels(): label.set_rotation(45) plt.setp(ax1.get_xticklabels(), visible=False) plt.setp(ax2.get_xticklabels(), visible=False) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('EBAY')
image看起来好了一些,但是仍然有一些东西需要清除。
第二十三章 共享 X 轴在这个 Matplotlib 数据可视化教程中,我们将讨论sharex
选项,它允许我们在图表之间共享x
轴。将sharex
看做『复制 x』也许更好。
在我们开始之前,首先我们要做些修剪并在另一个轴上设置最大刻度数,如下所示:
ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper'))
以及
ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper'))
现在,让我们共享所有轴域之间的x
轴。 为此,我们需要将其添加到轴域定义中:
fig=plt.figure()ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1)plt.title(stock)plt.ylabel('H-L')ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1)plt.ylabel('Price')ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1)plt.ylabel('MAvgs')
上面,对于ax2
和ax3
,我们添加一个新的参数,称为sharex
,然后我们说,我们要与ax1
共享x
轴。
使用这种方式,我们可以加载图表,然后我们可以放大到一个特定的点,结果将是这样:
image所以这意味着所有轴域沿着它们的x
轴一起移动。 这很酷吧!
接下来,让我们将[-start:]
应用到所有数据,所以所有轴域都起始于相同地方。 我们最终的代码为:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) plt.ylabel('H-L') ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1) plt.ylabel('Price') ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1) plt.ylabel('MAvgs') stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date[-start:],h_l[-start:],'-') ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower')) candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f') ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper')) ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax3.plot(date[-start:], ma1[-start:], linewidth=1) ax3.plot(date[-start:], ma2[-start:], linewidth=1) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] < ma2[-start:]), facecolor='r', edgecolor='r', alpha=0.5) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] > ma2[-start:]), facecolor='g', edgecolor='g', alpha=0.5) ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax3.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper')) for label in ax3.xaxis.get_ticklabels(): label.set_rotation(45) plt.setp(ax1.get_xticklabels(), visible=False) plt.setp(ax2.get_xticklabels(), visible=False) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('EBAY')
下面我们会讨论如何创建多个y
轴。
在这篇 Matplotlib 教程中,我们将介绍如何在同一子图上使用多个 Y 轴。 在我们的例子中,我们有兴趣在同一个图表及同一个子图上绘制股票价格和交易量。
为此,首先我们需要定义一个新的轴域,但是这个轴域是ax2
仅带有x
轴的『双生子』。
这足以创建轴域了。我们叫它ax2v
,因为这个轴域是ax2
加交易量。
现在,我们在轴域上定义绘图,我们将添加:
ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4)
我们在 0 和当前交易量之间填充,给予它蓝色的前景色,然后给予它一个透明度。 我们想要应用幽冥毒,以防交易量最终覆盖其它东西,所以我们仍然可以看到这两个元素。
所以,到现在为止,我们的代码为:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) plt.ylabel('H-L') ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1) plt.ylabel('Price') ax2v=ax2.twinx() ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1) plt.ylabel('MAvgs') stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date[-start:],h_l[-start:],'-') ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower')) candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f') ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper')) ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4) ax3.plot(date[-start:], ma1[-start:], linewidth=1) ax3.plot(date[-start:], ma2[-start:], linewidth=1) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] < ma2[-start:]), facecolor='r', edgecolor='r', alpha=0.5) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] > ma2[-start:]), facecolor='g', edgecolor='g', alpha=0.5) ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax3.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper')) for label in ax3.xaxis.get_ticklabels(): label.set_rotation(45) plt.setp(ax1.get_xticklabels(), visible=False) plt.setp(ax2.get_xticklabels(), visible=False) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('GOOG')
会生成:
image太棒了,到目前为止还不错。 接下来,我们可能要删除新y
轴上的标签,然后我们也可能不想让交易量占用太多空间。 没问题:
首先:
ax2v.axes.yaxis.set_ticklabels([])
上面将y
刻度标签设置为一个空列表,所以不会有任何标签了。
译者注:所以将标签删除之后,添加新轴的意义是什么?直接在原轴域上绘图就可以了。
接下来,我们可能要将网格设置为false
,使轴域上不会有双网格:
ax2v.grid(False)
最后,为了处理交易量占用很多空间,我们可以做以下操作:
ax2v.set_ylim(0, 3*volume.max())
所以这设置y
轴显示范围从 0 到交易量的最大值的 3 倍。 这意味着,在最高点,交易量最多可占据图形的33%。 所以,增加volume.max
的倍数越多,空间就越小/越少。
现在,我们的图表为:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure() ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) plt.ylabel('H-L') ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1) plt.ylabel('Price') ax2v=ax2.twinx() ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1) plt.ylabel('MAvgs') stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date[-start:],h_l[-start:],'-') ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower')) candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f') ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper')) ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+5, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4) ax2v.axes.yaxis.set_ticklabels([]) ax2v.grid(False) ax2v.set_ylim(0, 3*volume.max()) ax3.plot(date[-start:], ma1[-start:], linewidth=1) ax3.plot(date[-start:], ma2[-start:], linewidth=1) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] < ma2[-start:]), facecolor='r', edgecolor='r', alpha=0.5) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] > ma2[-start:]), facecolor='g', edgecolor='g', alpha=0.5) ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax3.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper')) for label in ax3.xaxis.get_ticklabels(): label.set_rotation(45) plt.setp(ax1.get_xticklabels(), visible=False) plt.setp(ax2.get_xticklabels(), visible=False) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) plt.show()graph_data('GOOG')
到这里,我们差不多完成了。 这里唯一的缺陷是一个好的图例。 一些线条是显而易见的,但人们可能会好奇移动均值的参数是什么,我们这里是 10 和 30。 添加自定义图例是下一个教程中涉及的内容。
第二十五章 自定义图例在这篇 Matplotlib 教程中,我们将讨论自定义图例。 我们已经介绍了添加图例的基础知识。
图例的主要问题通常是图例阻碍了数据的展示。 这里有几个选项。 一个选项是将图例放在轴域外,但是我们在这里有多个子图,这是非常困难的。 相反,我们将使图例稍微小一点,然后应用一个透明度。
首先,为了创建一个图例,我们需要向我们的数据添加我们想要显示在图例上的标签。
ax1.plot_date(date[-start:],h_l[-start:],'-', label='H-L')...ax2v.plot([],[], color='#0079a3', alpha=0.4, label='Volume')...ax3.plot(date[-start:], ma1[-start:], linewidth=1, label=(str(MA1)+'MA'))ax3.plot(date[-start:], ma2[-start:], linewidth=1, label=(str(MA2)+'MA'))
请注意,我们通过创建空行为交易量添加了标签。 请记住,我们不能对任何填充应用标签,所以这就是我们添加这个空行的原因。
现在,我们可以在右下角添加图例,通过在plt.show()
之前执行以下操作:
ax1.legend()ax2v.legend()ax3.legend()
会生成:
image所以,我们可以看到,图例还是占用了一些位置。 让我们更改位置,大小并添加透明度:
ax1.legend()leg=ax1.legend(loc=9, ncol=2,prop={'size':11})leg.get_frame().set_alpha(0.4)ax2v.legend()leg=ax2v.legend(loc=9, ncol=2,prop={'size':11})leg.get_frame().set_alpha(0.4)ax3.legend()leg=ax3.legend(loc=9, ncol=2,prop={'size':11})leg.get_frame().set_alpha(0.4)
所有的图例位于位置 9(上中间)。 有很多地方可放置图例,我们可以为参数传入不同的位置号码,来看看它们都位于哪里。 ncol
参数允许我们指定图例中的列数。 这里只有一列,如果图例中有 2 个项目,他们将堆叠在一列中。 最后,我们将尺寸规定为更小。 之后,我们对整个图例应用0.4
的透明度。
现在我们的结果为:
image完整的代码为:
import matplotlib.pyplot as pltimport matplotlib.dates as mdatesimport matplotlib.ticker as mtickerfrom matplotlib.finance import candlestick_ohlcfrom matplotlib import styleimport numpy as npimport urllibimport datetime as dtstyle.use('fivethirtyeight')print(plt.style.available)print(plt.__file__)MA1=10MA2=30def moving_average(values, window): weights=np.repeat(1.0, window)/window smas=np.convolve(values, weights, 'valid') return smasdef high_minus_low(highs, lows): return highs-lowsdef bytespdate2num(fmt, encoding='utf-8'): strconverter=mdates.strpdate2num(fmt) def bytesconverter(b): s=b.decode(encoding) return strconverter(s) return bytesconverterdef graph_data(stock): fig=plt.figure(facecolor='#f0f0f0') ax1=plt.subplot2grid((6,1), (0,0), rowspan=1, colspan=1) plt.title(stock) plt.ylabel('H-L') ax2=plt.subplot2grid((6,1), (1,0), rowspan=4, colspan=1, sharex=ax1) plt.ylabel('Price') ax2v=ax2.twinx() ax3=plt.subplot2grid((6,1), (5,0), rowspan=1, colspan=1, sharex=ax1) plt.ylabel('MAvgs') stock_price_url='http://chartapi.finance.yahoo.com/instrument/1.0/'+stock+'/chartdata;type=quote;range=1y/csv' source_code=urllib.request.urlopen(stock_price_url).read().decode() stock_data=[] split_source=source_code.split('\n') for line in split_source: split_line=line.split(',') if len(split_line)==6: if 'values' not in line and 'labels' not in line: stock_data.append(line) date, closep, highp, lowp, openp, volume=np.loadtxt(stock_data, delimiter=',', unpack=True, converters={0: bytespdate2num('%Y%m%d')}) x=0 y=len(date) ohlc=[] while x < y: append_me=date[x], openp[x], highp[x], lowp[x], closep[x], volume[x] ohlc.append(append_me) x+=1 ma1=moving_average(closep,MA1) ma2=moving_average(closep,MA2) start=len(date[MA2-1:]) h_l=list(map(high_minus_low, highp, lowp)) ax1.plot_date(date[-start:],h_l[-start:],'-', label='H-L') ax1.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='lower')) candlestick_ohlc(ax2, ohlc[-start:], width=0.4, colorup='#77d879', colordown='#db3f3f') ax2.yaxis.set_major_locator(mticker.MaxNLocator(nbins=7, prune='upper')) ax2.grid(True) bbox_props=dict(boxstyle='round',fc='w', ec='k',lw=1) ax2.annotate(str(closep[-1]), (date[-1], closep[-1]), xytext=(date[-1]+4, closep[-1]), bbox=bbox_props)## # Annotation example with arrow## ax2.annotate('Bad News!',(date[11],highp[11]),## xytext=(0.8, 0.9), textcoords='axes fraction',## arrowprops=dict(facecolor='grey',color='grey'))#### ## # Font dict example## font_dict={'family':'serif',## 'color':'darkred',## 'size':15}## # Hard coded text ## ax2.text(date[10], closep[1],'Text Example', fontdict=font_dict) ax2v.plot([],[], color='#0079a3', alpha=0.4, label='Volume') ax2v.fill_between(date[-start:],0, volume[-start:], facecolor='#0079a3', alpha=0.4) ax2v.axes.yaxis.set_ticklabels([]) ax2v.grid(False) ax2v.set_ylim(0, 3*volume.max()) ax3.plot(date[-start:], ma1[-start:], linewidth=1, label=(str(MA1)+'MA')) ax3.plot(date[-start:], ma2[-start:], linewidth=1, label=(str(MA2)+'MA')) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] < ma2[-start:]), facecolor='r', edgecolor='r', alpha=0.5) ax3.fill_between(date[-start:], ma2[-start:], ma1[-start:], where=(ma1[-start:] > ma2[-start:]), facecolor='g', edgecolor='g', alpha=0.5) ax3.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) ax3.xaxis.set_major_locator(mticker.MaxNLocator(10)) ax3.yaxis.set_major_locator(mticker.MaxNLocator(nbins=4, prune='upper')) for label in ax3.xaxis.get_ticklabels(): label.set_rotation(45) plt.setp(ax1.get_xticklabels(), visible=False) plt.setp(ax2.get_xticklabels(), visible=False) plt.subplots_adjust(left=0.11, bottom=0.24, right=0.90, top=0.90, wspace=0.2, hspace=0) ax1.legend() leg=ax1.legend(loc=9, ncol=2,prop={'size':11}) leg.get_frame().set_alpha(0.4) ax2v.legend() leg=ax2v.legend(loc=9, ncol=2,prop={'size':11}) leg.get_frame().set_alpha(0.4) ax3.legend() leg=ax3.legend(loc=9, ncol=2,prop={'size':11}) leg.get_frame().set_alpha(0.4) plt.show() fig.savefig('google.png', facecolor=fig.get_facecolor())graph_data('GOOG')
现在我们可以看到图例,但也看到了图例下的任何信息。 还要注意额外函数fig.savefig
。 这是自动保存图形的图像的方式。 我们还可以设置所保存的图形的前景色,使背景不是白色的,如我们的例子所示。
这就是目前为止,我想要显示的典型 Matplotlib 图表。 接下来,我们将涉及Basemap
,它是一个 Matplotlib 扩展,用于绘制地理位置,然后我打算讲解 Matplotlib 中的 3D 图形。
在这个 Matplotlib 教程中,我们将涉及地理绘图模块Basemap
。 Basemap
是 Matplotlib 的扩展。
为了使用Basemap
,我们首先需要安装它。 为了获得Basemap
,你可以从这里获取:http://matplotlib.org/basemap/users/download.html,或者你可以访问http://www.lfd.uci.edu/~gohlke/pythonlibs/。
如果你在安装Basemap
时遇到问题,请查看pip
安装教程。
一旦你安装了Basemap
,你就可以创建地图了。 首先,让我们投影一个简单的地图。 为此,我们需要导入Basemap
,pyplot
,创建投影,至少绘制某种轮廓或数据,然后我们可以显示图形。
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill')m.drawcoastlines()plt.show()
上面的代码结果如下:
image这是使用 Miller 投影完成的,这只是许多Basemap
投影选项之一。
在这篇 Matplotlib 教程中,我们继续使用Basemap
地理绘图扩展。 我们将展示一些我们可用的自定义选项。
首先,从上一个教程中获取我们的起始代码:
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill')m.drawcoastlines()plt.show()
我们可以从放大到特定区域来开始:
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill', llcrnrlat=-40, llcrnrlon=-40, urcrnrlat=50, urcrnrlon=75)m.drawcoastlines()plt.show()
这里的参数是:
llcrnrlat
- 左下角的纬度llcrnrlon
- 左下角的经度urcrnrlat
- 右上角的纬度urcrnrlon
- 右上角的经度此外,坐标需要转换,其中西经和南纬坐标是负值,北纬和东经坐标是正值。
使用这些坐标,Basemap
会选择它们之间的区域。
下面,我们要使用一些东西,类似:
m.drawcountries(linewidth=2)
这会画出国家,并使用线宽为 2 的线条生成分界线。
另一个选项是:
m.drawstates(color='b')
这会用蓝色线条画出州。
你也可以执行:
m.drawcounties(color='darkred')
这会画出国家。
所以,我们的代码是:
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill', llcrnrlat=-90, llcrnrlon=-180, urcrnrlat=90, urcrnrlon=180)m.drawcoastlines()m.drawcountries(linewidth=2)m.drawstates(color='b')m.drawcounties(color='darkred')plt.title('Basemap Tutorial')plt.show()
image很难说,但我们定义了美国的区县的线条。 我们可以使用放大镜放大Basemap
图形,就像其他图形那样,会生成:
另一个有用的选项是Basemap
调用中的『分辨率』选项。
m=Basemap(projection='mill', llcrnrlat=-90, llcrnrlon=-180, urcrnrlat=90, urcrnrlon=180, resolution='l')
分辨率的选项为:
c
- 粗糙l
- 低h
- 高f
- 完整对于更高的分辨率,你应该放大到很大,否则这可能只是浪费。
另一个选项是使用etopo()
绘制地形,如:
m.etopo()
使用drawcountries
方法绘制此图形会生成:
[图片上传失败...(image-eaca82-1558445064342)]
最后,有一个蓝色的大理石版本,你可以调用:
m.bluemarble()
会生成:
[图片上传失败...(image-3f9978-1558445064342)]
目前为止的代码:
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill', llcrnrlat=-90, llcrnrlon=-180, urcrnrlat=90, urcrnrlon=180, resolution='l')m.drawcoastlines()m.drawcountries(linewidth=2)##m.drawstates(color='b')##m.drawcounties(color='darkred')#m.fillcontinents()#m.etopo()m.bluemarble()plt.title('Basemap Tutorial')plt.show()
第二十八章 在 Basemap 中绘制坐标欢迎阅读另一个 Matplotlib Basemap 教程。 在本教程中,我们将介绍如何绘制单个坐标,以及如何在地理区域中连接这些坐标。
首先,我们将从一些基本的起始数据开始:
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill', llcrnrlat=25, llcrnrlon=-130, urcrnrlat=50, urcrnrlon=-60, resolution='l')m.drawcoastlines()m.drawcountries(linewidth=2)m.drawstates(color='b')
接下来,我们可以绘制坐标,从获得它们的实际坐标开始。 记住,南纬和西经坐标需要转换为负值。 例如,纽约市是北纬40.7127
西经74.0059
。 我们可以在我们的程序中定义这些坐标,如:
NYClat, NYClon=40.7127, -74.0059
之后我们将这些转换为要绘制的x
和y
坐标。
xpt, ypt=m(NYClon, NYClat)
注意这里,我们现在已经将坐标顺序翻转为lon, lat
(纬度,经度)。 坐标通常以lat, lon
顺序给出。 然而,在图形中,lat, long
转换为y, x
,我们显然不需要。 在某些时候,你必须翻转它们。 不要忘记这部分!
最后,我们可以绘制如下的坐标:
m.plot(xpt, ypt, 'c*', markersize=15)
这个图表上有一个青色的星,大小为 15。更多标记类型请参阅:Matplotlib 标记文档。
接下来,让我们再画一个位置,洛杉矶,加利福尼亚:
LAlat, LAlon=34.05, -118.25xpt, ypt=m(LAlon, LAlat)m.plot(xpt, ypt, 'g^', markersize=15)
这次我们画出一个绿色三角,执行代码会生成:
[图片上传失败...(image-8e5eff-1558445064342)]
如果我们想连接这些图块怎么办?原来,我们可以像其它 Matplotlib 图表那样实现它。
首先,我们将那些xpt
和ypt
坐标保存到列表,类似这样的东西:
xs=[]ys=[]NYClat, NYClon=40.7127, -74.0059xpt, ypt=m(NYClon, NYClat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'c*', markersize=15)LAlat, LAlon=34.05, -118.25xpt, ypt=m(LAlon, LAlat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'g^', markersize=15)m.plot(xs, ys, color='r', linewidth=3, label='Flight 98')
会生成:
image太棒了。有时我们需要以圆弧连接图上的两个坐标。如何实现呢?
m.drawgreatcircle(NYClon, NYClat, LAlon, LAlat, color='c', linewidth=3, label='Arc')
我们的完整代码为:
from mpl_toolkits.basemap import Basemapimport matplotlib.pyplot as pltm=Basemap(projection='mill', llcrnrlat=25, llcrnrlon=-130, urcrnrlat=50, urcrnrlon=-60, resolution='l')m.drawcoastlines()m.drawcountries(linewidth=2)m.drawstates(color='b')#m.drawcounties(color='darkred')#m.fillcontinents()#m.etopo()#m.bluemarble()xs=[]ys=[]NYClat, NYClon=40.7127, -74.0059xpt, ypt=m(NYClon, NYClat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'c*', markersize=15)LAlat, LAlon=34.05, -118.25xpt, ypt=m(LAlon, LAlat)xs.append(xpt)ys.append(ypt)m.plot(xpt, ypt, 'g^', markersize=15)m.plot(xs, ys, color='r', linewidth=3, label='Flight 98')m.drawgreatcircle(NYClon, NYClat, LAlon, LAlat, color='c', linewidth=3, label='Arc')plt.legend(loc=4)plt.title('Basemap Tutorial')plt.show()
结果为:
image这就是Basemap
的全部了,下一章关于 Matplotlib 的 3D 绘图。
您好,欢迎阅读 Matplotlib 教程中的 3D 绘图。 Matplotlib 已经内置了三维图形,所以我们不需要再下载任何东西。 首先,我们需要引入一些完整的模块:
from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as plt
使用axes3d
是因为它需要不同种类的轴域,以便在三维中实际绘制一些东西。 下面:
fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')
在这里,我们像通常一样定义图形,然后我们将ax1
定义为通常的子图,只是这次使用 3D 投影。 我们需要这样做,以便提醒 Matplotlib 我们要提供三维数据。
现在让我们创建一些 3D 数据:
x=[1,2,3,4,5,6,7,8,9,10]y=[5,6,7,8,2,5,6,3,7,2]z=[1,2,6,3,2,7,3,3,7,2]
接下来,我们绘制它。 首先,让我们展示一个简单的线框示例:
ax1.plot_wireframe(x,y,z)
最后:
ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()
我们完整的代码是:
from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltfrom matplotlib import stylestyle.use('fivethirtyeight')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x=[1,2,3,4,5,6,7,8,9,10]y=[5,6,7,8,2,5,6,3,7,2]z=[1,2,6,3,2,7,3,3,7,2]ax1.plot_wireframe(x,y,z)ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()
结果为(包括所用的样式):
image这些 3D 图形可以进行交互。 首先,您可以使用鼠标左键单击并拖动来移动图形。 您还可以使用鼠标右键单击并拖动来放大或缩小。
第三十章 3D 散点图欢迎阅读另一个 3D Matplotlib 教程,会涉及如何绘制三维散点图。
绘制 3D 散点图非常类似于通常的散点图以及 3D 线框图。
一个简单示例:
from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltfrom matplotlib import stylestyle.use('ggplot')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x=[1,2,3,4,5,6,7,8,9,10]y=[5,6,7,8,2,5,6,3,7,2]z=[1,2,6,3,2,7,3,3,7,2]x2=[-1,-2,-3,-4,-5,-6,-7,-8,-9,-10]y2=[-5,-6,-7,-8,-2,-5,-6,-3,-7,-2]z2=[1,2,6,3,2,7,3,3,7,2]ax1.scatter(x, y, z, c='g', marker='o')ax1.scatter(x2, y2, z2, c='r', marker='o')ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()
结果为:
image要记住你可以修改这些绘图的大小和标记,就像通常的散点图那样。
第三十一章 3D 条形图在这个 Matplotlib 教程中,我们要介绍 3D 条形图。 3D 条形图是非常独特的,因为它允许我们绘制多于 3 个维度。 不,你不能超过第三个维度来绘制,但你可以绘制多于 3 个维度。
对于条形图,你需要拥有条形的起点,条形的高度和宽度。 但对于 3D 条形图,你还有另一个选项,就是条形的深度。 大多数情况下,条形图从轴上的条形平面开始,但是你也可以通过打破此约束来添加另一个维度。 然而,我们会让它非常简单:
from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltimport numpy as npfrom matplotlib import stylestyle.use('ggplot')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x3=[1,2,3,4,5,6,7,8,9,10]y3=[5,6,7,8,2,5,6,3,7,2]z3=np.zeros(10)dx=np.ones(10)dy=np.ones(10)dz=[1,2,3,4,5,6,7,8,9,10]ax1.bar3d(x3, y3, z3, dx, dy, dz)ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()
注意这里,我们必须定义x
、y
和z
,然后是 3 个维度的宽度、高度和深度。 这会生成:
欢迎阅读最后的 Matplotlib 教程。 在这里我们将整理整个系列,并显示一个稍微更复杂的 3D 线框图:
from mpl_toolkits.mplot3d import axes3dimport matplotlib.pyplot as pltimport numpy as npfrom matplotlib import stylestyle.use('ggplot')fig=plt.figure()ax1=fig.add_subplot(111, projection='3d')x, y, z=axes3d.get_test_data()print(axes3d.__file__)ax1.plot_wireframe(x,y,z, rstride=3, cstride=3)ax1.set_xlabel('x axis')ax1.set_ylabel('y axis')ax1.set_zlabel('z axis')plt.show()
image如果你从一开始就关注这个教程的话,那么你已经学会了 Matplotlib 提供的大部分内容。 你可能不相信,但Matplotlib 仍然可以做很多其他的事情! 请继续学习,你可以随时访问 Matplotlib.org,并查看示例和图库页面。
如果你发现自己大量使用 Matplotlib,请考虑捐助给 John Hunter Memorial 基金。
注:空间曲面的画法
# 二次抛物面 z=x^2 + y^2x=np.linspace(-10, 10, 101)y=xx, y=np.meshgrid(x, y)z=x ** 2 + y ** 2ax=plot.subplot(111, projection='3d')ax.plot_wireframe(x, y, z)plot.show()
image
# 半径为 1 的球t=np.linspace(0, np.pi * 2, 100)s=np.linspace(0, np.pi, 100)t, s=np.meshgrid(t, s)x=np.cos(t) * np.sin(s)y=np.sin(t) * np.sin(s)z=np.cos(s)ax=plot.subplot(111, projection='3d')ax.plot_wireframe(x, y, z)plot.show()
image
参考链接
导读:获取数据之后,而不知道如何查看数据,用途还是有限的。幸好,我们有Matplotlib!
Matplotlib 是基于 NumPy 数组构建的多平台数据可视化库。它是John Hunter 在2002年构想的,原本的设计是给 IPython 打补丁,让命令行中也可以有交互式的 MATLAB 风格的画图工具。
在近些年,已经出现了更新更好的工具最终替代了 Matplotlib(比如 R 语言中的ggplot和ggvis), 但 Matplotlib 依旧是一个经过良好测试的、跨平台的图形引擎。
作者:迈克尔·贝耶勒(Michael Beyeler)
如需转载请联系华章科技
01 引入 Matplotlib
如果已安装Anaconda Python版本,就已经安装好了可以使用的 Matplotlib。否则,可能要访问官网并从中获取安装说明:
http://matplotlib.org
正如使用np作为 NumPy 的缩写,我们将使用一些标准的缩写来表示 Matplotlib 的引入:
In [1]: import matplotlib as mplIn [2]: import matplotlib.pyplot as plt
在本书中,plt接口会被频繁使用。
02 生成一个简单的绘图
让我们创建第一个绘图。
假设想要画出正弦函数sin(x)的线性图。得到函数在x坐标轴上0≤x<10内所有点的值。我们将使用 NumPy 中的 linspace 函数来在x坐标轴上创建一个从0到10的线性空间,以及100个采样点:
In [3]: import numpy as npIn [4]: x=np.linspace(0, 10, 100)
可以使用 NumPy 中的sin函数得到所有x点的值,并通过调用plt中的plot函数把结果画出来:
In [5]: plt.plot(x, np.sin(x))
你亲自尝试了吗?发生了什么吗?有没有什么东西出现?
实际情况是,取决于你在哪里运行脚本,可能无法看到任何东西。有下面几种可能性:
1. 从.py脚本中绘图
如果从一个脚本中运行 Matplotlib,需要加上下面的这行调用:
plt.show()
在脚本末尾调用这个函数,你的绘图就会出现!
2. 从 IPython shell 中绘图
这实际上是交互式地执行Matplotlib最方便的方式。为了让绘图出现,需要在启动 IPython 后使用所谓的%matplotlib魔法命令。
In [1]: %matplotlibUsing matplotlib backend: TkAggIn [2]: import matplotlib.pyplot as plt
接下来,无须每次调用plt.show()函数,所有的绘图将会自动出现。
3. 从 Jupyter Notebook 中绘图
如果你是从基于浏览器的 Jupyter Notebook 中看这段代码,需要使用同样的%matplotlib魔法命令。然而,也可以直接在notebook中嵌入图形,这会有两种输出选项:
%matplotlib notebook 将会把交互式的绘图嵌入到notebook中%matplotlib inline 将会把绘图的静态图像嵌入到notebook中在本书中,将会使用inline选项:
In [6]: %matplotlib inline
现在再次尝试一下:
In [7]: plt.plot(x, np.sin(x))Out[7]: [<matplotlib.lines.Line2D at 0x7f3aac426eb8>]
上面的命令会得到下面的绘图输出结果:
▲使用 Matplotlib 绘制正弦函数图像
如果想要把绘图保存下来留作以后使用,可以直接在 IPython 或者 Jupyter Notebook 使用下面的命令保存:
In [8]: plt.savefig('figures/02.03-sine.png')
仅需要确保你使用了支持的文件后缀,比如.jpg、.png、.tif、.svg、.eps或者.pdf。
Tips:可以在引入Matplotlib后通过运行plt.style.use(style_name)来修改绘图的风格。所有可用的风格在plt.style.available中列出。比如,尝试使用plt.style.use('fivethirtyeight')、plt.style.use('ggplot')或者plt.style.use('seaborn-dark')。为了更好玩,可以运行plt.xkcd(),然后尝试绘制一些别的图形。
03 可视化外部数据集的数据
作为本章最后一个测试,让我们对外部数据集进行可视化,比如scikit-learn中的数字数据集。
为此,需要三个可视化工具:
scikit-learn用于获取实际数据NumPy 用于数据再加工Matplotlib那么开始引入这些包吧:
In [1]: import numpy as np... from sklearn import datasets... import matplotlib.pyplot as plt... % matplotlib inline
第一步是载入实际数据:
In [2]: digits=datasets.load_digits()
如果没记错的话,digits应该有两个不同的数据域:data域包含了真正的图像数据,target域包含了图像的标签。相对于相信我们的记忆,我们还是应该对digits稍加探索。输入它的名字,添加一个点号,然后按Tab键:digits.<TAB>,这个操作将向我们展示digits也包含了一些其他的域,比如一个名为images的域。images和data这两个域,似乎简单从形状上就可以区分。
In [3]: print(digits.data.shape)... print(digits.images.shape)Out[3]: (1797, 64)... (1797, 8, 8)
两种情况中,第一维对应的都是数据集中的图像数量。然而,data中所有像素都在一个大的向量中排列,而images保留了各个图像8×8的空间排列。
因此,如果想要绘制出一副单独的图像,使用images将更加合适。首先,使用NumPy的数组切片从数据集中获取一幅图像:
In [4]: img=digits.images[0, :, :]
这里是从1797个元素的数组中获取了它的第一行数据,这行数据对应的是8×8=64个像素。下面就可以使用plt中的imshow函数来绘制这幅图像:
In [5]: plt.imshow(img, cmap='gray')Out[5]: <matplotlib.image.AxesImage at 0x7efcd27f30f0>
上面的命令得到下面的输出:
▲数字数据集中的一个图像样例
此外,这里也使用cmap参数指定了一个颜色映射。默认情况下,Matplotlib 使用MATLAB默认的颜色映射jet。然而,在灰度图像的情况下,gray颜色映射更有效。
最后,可以使用plt的subplot函数绘制全部数字的样例。subplot函数与MATLAB中的函数一样,需要指定行数、列数以及当前的子绘图索引(从1开始计算)。我们将使用for 循环在数据集中迭代出前十张图像,每张图像都分配到一个单独的子绘图中。
In [6]: for image_index in range(10):... # 图像按0开始索引,子绘图按1开始索引... subplot_index=image_index + 1... plt.subplot(2, 5, subplot_index)... plt.imshow(digits.images[image_index, :, :], cmap='gray')
这会得到下面的输出结果:
▲数字数据集中的十幅样例图像
Tips:另一个拥有各种数据集资源的是我的母校——加利福尼亚大学欧文分校的机器学习仓库:
http://archive.ics.uci.edu/ml
关于作者:Michael Beyeler,华盛顿大学神经工程和数据科学专业的博士后,主攻仿生视觉计算模型,用以为盲人植入人工视网膜(仿生眼睛),改善盲人的视觉体验。 他的工作属于神经科学、计算机工程、计算机视觉和机器学习的交叉领域。同时他也是多个开源项目的积极贡献者。
本文摘编自《机器学习:使用OpenCV和Python进行智能图像处理》,经出版方授权发布。
延伸阅读《机器学习》
推荐语:OpenCV是一个综合了经典和先进计算机视觉、机器学习算法的开源库。通过与Python Anaconda版本结合,你就可以获取你所需要的所有开源计算库。 本书首先介绍分类和回归等统计学习的基本概念,然后详细讲解决策树、支持向量机和贝叶斯网络等算法,以及如何把它们与其他OpenCV函数结合。
上一篇:【赞美日记8】
发表评论